From b51b4f0291c17df62393715543ac9a5603eff453 Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Thu, 23 Jan 2025 22:07:31 +0800 Subject: [PATCH] add onnx-subgraph baseline code --- tools/onnx-subgraph/.Readme.md.swp | Bin 0 -> 4096 bytes tools/onnx-subgraph/CMakeLists.txt | 81 + tools/onnx-subgraph/Readme.md | 85 + tools/onnx-subgraph/config-sample-1.json | 10 + tools/onnx-subgraph/config-sample-2.json | 15 + tools/onnx-subgraph/config.json | 13 + tools/onnx-subgraph/extract_onnx.py | 21 + tools/onnx-subgraph/extract_onnx_lib.py | 204 + tools/onnx-subgraph/include/device.h | 159 + tools/onnx-subgraph/include/graph.h | 165 + tools/onnx-subgraph/include/json-forwards.h | 265 + tools/onnx-subgraph/include/json.h | 1996 +++++++ tools/onnx-subgraph/include/partition.h | 56 + tools/onnx-subgraph/model_inference.py | 323 ++ .../model_inference_multiple_output.py | 327 ++ tools/onnx-subgraph/onnx.proto | 871 +++ tools/onnx-subgraph/onnx_subgraph_ut.py | 62 + tools/onnx-subgraph/quant.py | 425 ++ .../onnx-subgraph/single_vs_multiple_onnx.py | 106 + tools/onnx-subgraph/src/lib/device.cpp | 82 + tools/onnx-subgraph/src/lib/graph.cpp | 213 + tools/onnx-subgraph/src/lib/jsoncpp.cpp | 4951 +++++++++++++++++ tools/onnx-subgraph/src/lib/partition.cpp | 2727 +++++++++ tools/onnx-subgraph/src/lib/structures.cpp | 69 + tools/onnx-subgraph/src/main.cpp | 62 + tools/onnx-subgraph/test_model_download.sh | 16 + 26 files changed, 13304 insertions(+) create mode 100644 tools/onnx-subgraph/.Readme.md.swp create mode 100644 tools/onnx-subgraph/CMakeLists.txt create mode 100644 tools/onnx-subgraph/Readme.md create mode 100644 tools/onnx-subgraph/config-sample-1.json create mode 100644 tools/onnx-subgraph/config-sample-2.json create mode 100644 tools/onnx-subgraph/config.json create mode 100644 tools/onnx-subgraph/extract_onnx.py create mode 100644 tools/onnx-subgraph/extract_onnx_lib.py create mode 100644 tools/onnx-subgraph/include/device.h create mode 100644 tools/onnx-subgraph/include/graph.h create mode 100644 tools/onnx-subgraph/include/json-forwards.h create mode 100644 tools/onnx-subgraph/include/json.h create mode 100644 tools/onnx-subgraph/include/partition.h create mode 100644 tools/onnx-subgraph/model_inference.py create mode 100644 tools/onnx-subgraph/model_inference_multiple_output.py create mode 100644 tools/onnx-subgraph/onnx.proto create mode 100644 tools/onnx-subgraph/onnx_subgraph_ut.py create mode 100644 tools/onnx-subgraph/quant.py create mode 100644 tools/onnx-subgraph/single_vs_multiple_onnx.py create mode 100644 tools/onnx-subgraph/src/lib/device.cpp create mode 100644 tools/onnx-subgraph/src/lib/graph.cpp create mode 100644 tools/onnx-subgraph/src/lib/jsoncpp.cpp create mode 100644 tools/onnx-subgraph/src/lib/partition.cpp create mode 100644 tools/onnx-subgraph/src/lib/structures.cpp create mode 100644 tools/onnx-subgraph/src/main.cpp create mode 100644 tools/onnx-subgraph/test_model_download.sh diff --git a/tools/onnx-subgraph/.Readme.md.swp b/tools/onnx-subgraph/.Readme.md.swp new file mode 100644 index 0000000000000000000000000000000000000000..b4a498073005fff174a6f37f6b1fe747fc4650ec GIT binary patch literal 4096 zcmYc?2=nw+u+%eT00IF9hGpqp>7t1-OiBq149OX(d6g9)3H$(}PPZ(-xTGi_kJ37b zLj9D?;%r0xl+@IMoYcgkyv)3GegEJveSbe!{i4)j{rtSV3f=20.04 + 2. GCC >= 9.4.0 + 3. cmake >= 3.10 + 4. python >= 3.8 + 5. apt-get install libprotobuf-dev protobuf-compiler + +## Python packages dependence + onnx 1.16.0 + onnxruntime 1.18.1 + onnxsim 0.4.36 + torch 2.3.1 + scikit-image + scikit-learn + pandas + tqdm + +## building the onnx-subgraph + 1. cd onnx-subgraph + 2. mkdir build & cd build + 3. cmake .. & make + 4. we can get following output at ./build + ├── onnx-subgraph + └── scripts + ├── config.json + ├── config-sample-1.json + ├── config-sample-2.json + ├── extract_onnx_lib.py + ├── extract_onnx.py + ├── model_inference_multiple_output.py + ├── model_inference.py + ├── onnx_subgraph_ut.py + ├── quant.py + ├── single_vs_multiple_onnx.py + └── test_model_download.sh +# How to use the onnx-subgraph +## Pre-steps +### Download the test AI models + 1. bash scripts/test_model_download.sh, then "resnet-test.onnx" will be got in ./build + 2. you can change to any other onnx files as your needs, or edit the download link in "scripts/test_model_download.sh" +### Prepare the config.json + 1. edit the config.json + . you can edit operators in "NPU_supported_ops" and "CPU_supported_ops"; + . you can edit performance data in "performance_data" as the real HW status, + . you can edit "max_subgraph_size" in case of "NPU_supported_ops" is [] + 2. you can also check more examples in "config-sample-1.json" and "config-sample-2.json" + + +## Parse the onnx model + ./onnx-subgraph --onnx=resnet-test.onnx + after parsing done, subgraphs_ios.txt will be generated at current path + +## Split the onnx model to subgraphs + 1. edit the config path and model file path at ./scripts/extract_onnx.py + e.g.: extract_onnx_lib.split_onnx_ios('./subgraphs_ios.txt','./resnet-test.onnx') + 2. python scripts/extract_onnx.py, after extraction done, the subgraphs will be saved at './subgraphs' + subgraphs + ├── CPU + │   ├── CPUsubgraph0.onnx + │   └── CPUsubgraph1.onnx + └── NPU + ├── NPUsubgraph0.onnx + └── NPUsubgraph1.onnx + +### Verify the subgraphs inference with original model file + 1. edit the model path, subgraph path and config path in ./scripts/single_vs_multiple_onnx.py + single_onnx_model_path = './resnet-test.onnx' + model_path = './subgraphs/' + subgraphsiostxt_path = './subgraphs_ios.txt' + 2. edit the input shape and name of onnx model in ./scripts/single_vs_multiple_onnx.py + default_input_data = { + "x": np.random.rand(1, 3, 256, 256).astype(np.float32), + } + 3. compare the MSE of original inference result and subgraphs inference result + python ./scripts/single_vs_multiple_onnx.py + output: + Single model inference completed! + Multiple subgraph inference completed! + Comparing inference results between single ONNX model and multiple subgraphs... + Output '316' MSE: 5.125894080395578e-14 diff --git a/tools/onnx-subgraph/config-sample-1.json b/tools/onnx-subgraph/config-sample-1.json new file mode 100644 index 00000000000..3e083ca5b64 --- /dev/null +++ b/tools/onnx-subgraph/config-sample-1.json @@ -0,0 +1,10 @@ +{ + "NPU_supported_ops": [], + "CPU_supported_ops": ["Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div","Transpose", "Gather", "MatMul", "Mul", "Softmax", "Erf", "Gemm", "Conv", "Reshape", + "Sin", "Where", "ConstantOfShape", "Cast", "Sigmoid", "Cos", "Expand", "Slice", "Unsqueeze"], + "performance_data": [], + "hardware_limits": { + "max_subgraph_size": 10240.0, + "max_subgraphs": 5 + } +} diff --git a/tools/onnx-subgraph/config-sample-2.json b/tools/onnx-subgraph/config-sample-2.json new file mode 100644 index 00000000000..02e840a723b --- /dev/null +++ b/tools/onnx-subgraph/config-sample-2.json @@ -0,0 +1,15 @@ +{ + "NPU_supported_ops": ["Conv", "Reshape", "Transpose", "Add", "ReduceMean", "Sub", "Div", "Mul", "Sigmoid","MatMul"], + "CPU_supported_ops": ["Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div","Transpose", "Gather", "MatMul", "Mul", "Softmax", "Erf", "Gemm", "Conv", "Reshape", + "Sin", "Where", "ConstantOfShape", "Cast", "Sigmoid", "Cos", "Expand", "Slice", "Unsqueeze"], + "performance_data": [ + {"name":"Conv","CPU_time": 0.1, "NPU_time": 0.05}, + {"name":"Mul", "CPU_time": 0.15, "NPU_time": 0.07} + {"name":"Add", "CPU_time": 0.15, "NPU_time": 0.07} + {"name":"Sub", "CPU_time": 0.15, "NPU_time": 0.07} + ], + "hardware_limits": { + "max_subgraph_size": 60024.0, + "max_subgraphs": 5 + } +} diff --git a/tools/onnx-subgraph/config.json b/tools/onnx-subgraph/config.json new file mode 100644 index 00000000000..6d0b7ce5ace --- /dev/null +++ b/tools/onnx-subgraph/config.json @@ -0,0 +1,13 @@ +{ + "NPU_supported_ops": ["Conv", "Reshape", "Transpose", "Add", "ReduceMean", "Sub", "Div", "Mul", "Sigmoid","MatMul"], + "CPU_supported_ops": ["Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div","Transpose", "Gather", "MatMul", "Mul", "Softmax", "Erf", "Gemm", "Conv", "Reshape", + "Sin", "Where", "ConstantOfShape", "Cast", "Sigmoid", "Cos", "Expand", "Slice", "Unsqueeze"], + "performance_data": [ + {"name":"Conv","CPU_time": 0.1, "NPU_time": 0.05}, + {"name":"Mul", "CPU_time": 0.15, "NPU_time": 0.07} + ], + "hardware_limits": { + "max_subgraph_size": 60024.0, + "max_subgraphs": 5 + } +} diff --git a/tools/onnx-subgraph/extract_onnx.py b/tools/onnx-subgraph/extract_onnx.py new file mode 100644 index 00000000000..cbcc6a8f6bc --- /dev/null +++ b/tools/onnx-subgraph/extract_onnx.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import extract_onnx_lib +import torch +import onnx +import re +print("python executed") +extract_onnx_lib.split_onnx_ios('./subgraphs_ios.txt','./resnet-test.onnx') + diff --git a/tools/onnx-subgraph/extract_onnx_lib.py b/tools/onnx-subgraph/extract_onnx_lib.py new file mode 100644 index 00000000000..d1de0baca90 --- /dev/null +++ b/tools/onnx-subgraph/extract_onnx_lib.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import onnx +import re +import os +def splitinstruction(instr): + iolist=re.split('--input-name \"|\" --output-name \"|\" --input-shape \"',instr) + del iolist[0] + del iolist[-1] + in_=iolist[0].split(';') + out_=iolist[1].split(';') + return in_,out_ +def splitsubgraph_ios(iofile): + iolist=re.split('--input-name |;--output-name ',iofile) + in_=iolist[1].split(';') + out_=iolist[2].split(';') + del out_[-1] + type=iolist[0].split('subgraph')[0] + return in_,out_,type + +def split_onnx(instrfile,type): + print("module found") + f1=open(instrfile,"r") + lines=f1.readlines() + count=0 + for line in lines: + input_names, output_names = splitinstruction(line) + input_path ='net/diffusion_model_fp32_with_shape.onnx' + output_path ='diffusion_model_fp32_subgraphs_'+type+'/'+type+'subgraph'+str(count)+'.onnx' + count=count+1 + if((input_names!=['']) and (output_names!=[''])): + onnx.utils.extract_model(input_path, output_path, input_names, output_names) + f1.close() + +def split_onnx_ios(instrfile, input_path ='net/generation_model_simplify.onnx', out_folder = 'subgraphs/'): + if not os.path.exists(input_path): + print(input_path +" not exist") + return + + model = onnx.load(input_path) + onnx.checker.check_model(input_path) + for output in model.graph.output: + model.graph.value_info.append(output) + onnx.save(model, input_path) + f1=open(instrfile,"r") + lines=f1.readlines() + cpu_count = 0 + npu_count = 0 + count=0 + if not os.path.exists(out_folder): + os.makedirs(out_folder) + for line in lines: + input_names, output_names,type = splitsubgraph_ios(line) + if(type=='CPU'): + count=cpu_count + cpu_count=cpu_count+1 + else: + count=npu_count + npu_count=npu_count+1 + output_path_folder =out_folder+type+'/' + if not os.path.exists(output_path_folder): + os.makedirs(output_path_folder) + output_path = output_path_folder+type+'subgraph'+str(count)+'.onnx' + if((input_names!=['']) and (output_names!=[''])): + onnx.utils.extract_model(input_path, output_path, input_names, output_names) + print("succeed",count) + count = count+1 + f1.close() + +def rename_node_io(file_path): + model = onnx.load(file_path) + graph = model.graph + for inputs in graph.input : + inputs.name = re.sub(r'[/.]','',inputs.name) + for outputs in graph.output : + outputs.name = re.sub(r'[/.]','',outputs.name) + for value_infos in graph.value_info : + value_infos.name = re.sub(r'[/.]','',value_infos.name) + for initializers in graph.initializer : + initializers.name = re.sub(r'[/.]','',initializers.name) + for node in graph.node: + node.name = re.sub(r'[/.]','',node.name) + for i in range(len(node.input)): + node.input[i] = re.sub(r'[/.]','',node.input[i]) + for i in range(len(node.output)): + node.output[i] = re.sub(r'[/.]','',node.output[i]) + return model + +def rename_subgraph_node_ios(in_file_path,out_file_path): + file_names = os.listdir(in_file_path) + for filename in file_names: + filename_=in_file_path+'/'+filename + model=rename_node_io(filename_) + output_file_path = out_file_path+'/'+filename + onnx.save(model, output_file_path) + print(f'Modified model saved to {output_file_path}') + + +def print_model(file_path): + model = onnx.load(file_path) + graph = model.graph + size=0 + for node in graph.node: + size=size+1 + print(size) + +def sort(ifile_path,ofile_path): + finished_flag = 0 + sort_count = 0 + f1=open(ifile_path,"r") + lines=f1.readlines() + graphs_inputs = {} + graphs_outputs = {} + order_Subgraphs = {} + issort_Subgraphs = {} + TYPE = {} + index = 0 + for line in lines: + input_names, output_names,type = splitsubgraph_ios(line) + graphs_inputs[index] = input_names + graphs_outputs[index] = output_names + TYPE[index] = type + index = index + 1 + graph_num = index + f1.close() + while finished_flag == 0: + finished_flag = 1 + if(sort_count) == 0: + for i in range(graph_num): + find_flag = 0 + for g_input in graphs_inputs[i]: + for j in range(graph_num): + if g_input in graphs_outputs[j]: + find_flag = 1 + break + if find_flag == 1: + break + if find_flag == 0: + order_Subgraphs[i] = 0 + issort_Subgraphs[i] = 1 + else: + order_Subgraphs[i] = 1 + issort_Subgraphs[i] = 0 + finished_flag = 0 + else: + for i in range(graph_num): + find_flag = 0 + if issort_Subgraphs[i] == 1: + continue + for g_input in graphs_inputs[i]: + for j in range(graph_num): + if g_input in graphs_outputs[j]: + if issort_Subgraphs[j] == 0: + find_flag = 1 + break + if find_flag == 1: + break + if find_flag == 0: + order_Subgraphs[i] = sort_count + issort_Subgraphs[i] = 1 + else: + order_Subgraphs[i] = sort_count + 1 + issort_Subgraphs[i] = 0 + finished_flag = 0 + if i == graph_num - 1: + for j in range(graph_num): + if order_Subgraphs[j]==sort_count: + issort_Subgraphs[j] = 1 + print(order_Subgraphs) + print(issort_Subgraphs) + sort_count = sort_count + 1 + f2 = open(ofile_path,"w") + count_cpu = 0 + count_npu = 0 + for i in range(graph_num): + content = "" + if TYPE[i] == 'CPU': + content = "CPUsubgraph" + str(count_cpu) + ": order" + str(order_Subgraphs[i]) + "--input-name " + count_cpu = count_cpu + 1 + if TYPE[i] == 'NPU': + content = "NPUsubgraph" + str(count_npu) + ": order" + str(order_Subgraphs[i]) + "--input-name " + count_npu = count_npu + 1 + for graph_input in graphs_inputs[i]: + content = content + graph_input + ";" + content = content + "--output-name " + for graph_output in graphs_outputs[i]: + content = content + graph_output + ";" + content = content + "\n" + print(content) + f2.write(content) + f2.close() diff --git a/tools/onnx-subgraph/include/device.h b/tools/onnx-subgraph/include/device.h new file mode 100644 index 00000000000..e42f3bc8cc3 --- /dev/null +++ b/tools/onnx-subgraph/include/device.h @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DEVICE_H +#define DEVICE_H + +#include +#include +#include +#include +#include "onnx.pb.h" +#include "graph.h" +#include "json.h" +enum class DeviceType { Target_NPU }; + +class Device { +private: + std::string onnxFile; +public: + Device(/* args */) { + NPUPreferOp = {}; + CPUSupportOp = {}; + NPUSupportOp = {}; + //NPUPreferOp = {"Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div","Transpose","Gemm","MatMul"}; + // NPUPreferOp = {"Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div","Transpose", "Gather", "MatMul", "Mul", "Softmax", "Erf", "Gemm", "Conv", "Reshape", + // "Sin", "Where", "ConstantOfShape", "Cast", "Sigmoid", "Cos", "Expand", "Slice", "Unsqueeze","LayerNormalization","Concat","Shape","Squeeze","Mod","Pad","Range","Tile","Equal","Less","InstanceNormalization","Resize","Split","Clip","BatchNormalization","Identity"}; + // NPUSupportOp = {"Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div","Transpose", "Gather", "MatMul", "Mul", "Softmax", "Erf", "Gemm", "Conv", "Reshape", + // "Sin", "Where", "ConstantOfShape", "Cast", "Sigmoid", "Cos", "Expand", "Slice", "Unsqueeze","LayerNormalization","Concat","Shape","Squeeze","Mod","Pad","Range","Tile","Equal","Less","InstanceNormalization","Resize","Split","Clip","BatchNormalization","Identity"}; + max_subgraph_size = 0; + } + ~Device() {} + std::vector NPUPreferOp; + std::vector CPUSupportOp; + std::vector NPUSupportOp; + float max_subgraph_size; + DeviceType getType() { + return DeviceType::Target_NPU; + } + std::vector> getCPUStructure() { + return { + {"Concat"}, + {"Sub", "Pow", "ReduceMean", "Add", "Sqrt", "Div"}, + {"Transpose", "Gather", "Gather", "Gather", "Transpose", "MatMul", "Mul", "Softmax", "MatMul"} + }; + } + std::vector> getNPUStructure() { + return { + {"Reshape","Transpose","Reshape"}, + {"Reshape","Sigmoid","Mul","Transpose","Conv","Add","Transpose"}, + {"Reshape","Transpose","Conv","Transpose","Reshape"}, + {"Reshape","Conv","Transpose"}, + {"Reshape","Add","Add","Reshape","Transpose","Conv","Add"}, + {"Conv"} + }; + } + std::vector getNPUSupportOp() { + return NPUSupportOp; + } + std::vector getCPUSupportOp() { + return CPUSupportOp; + } + + std::vector getNPUPreferOp() { + return NPUPreferOp; + } + /** + * @brief Generate cut instructions for subgraphs based on the given device type. + * + * @param [in] Subgraphs A reference to a vector of ONNX GraphProto objects representing subgraphs. + * @param [in] device A string indicating the device type (e.g., "npu" or "c920"). + * @param [in] subgraphs_inputs A reference to a vector of unordered sets containing input information for subgraphs. + * @param [in] subgraphs_outputs A reference to a vector of unordered sets containing output information for subgraphs. + * + * @pre The function assumes that the `Subgraphs`, `subgraphs_inputs`, and `subgraphs_outputs` vectors are properly initialized and have the same size. + * @post A file named ` CutInstruction.txt` is created or overwritten with the generated cut instructions. + * @exception If the output file cannot be opened, an error message is printed, and the program exits. + * + * @return None + */ + void GenerateCutInstruction(std::vector& Subgraphs, std::string device, + std::vector> &subgraphs_inputs, std::vector> &subgraphs_outputs); + /** + * @brief Reads and parses a JSON file containing device information. + * + * This function reads a JSON file from the specified path, parses it, and extracts relevant device information. + * It updates global variables with hardware limits, preferred NPU operations, and supported operations for both NPU and CPU. + * + * @param json_path The file path to the JSON file containing device information. + */ + void GetDeviceJson(std::string json_path) + { + Json::Reader reader; + Json::Value root; + + // Open the JSON file in binary mode + std::ifstream in(json_path, std::ios::binary); + if (!in.is_open()) + { + std::cout << "Error opening file\n"; + return; + } + if(reader.parse(in, root)) + { + // Extract and set the maximum subgraph size from hardware limits + float max_subgraph_size_json = root["hardware_limits"]["max_subgraph_size"].asFloat(); + max_subgraph_size = max_subgraph_size_json; + // Iterate through performance data to identify operations where NPU outperforms CPU + for (unsigned int i = 0; i < root["performance_data"].size(); i++) + { + if(root["performance_data"][i]["CPU_time"].asFloat() > root["performance_data"][i]["NPU_time"].asFloat()) + { + NPUPreferOp.push_back(root["performance_data"][i]["name"].asString()); + } + + } + // Iterate through and store supported NPU operations + for(int i = 0; i < int(root["NPU_supported_ops"].size()); i++) + { + if(std::find(NPUSupportOp.begin(), NPUSupportOp.end(), root["NPU_supported_ops"][i].asString())== NPUSupportOp.end()) + { + NPUSupportOp.push_back(root["NPU_supported_ops"][i].asString()); + } + } + // Iterate through and store supported CPU operations + for(int i = 0; i < int(root["CPU_supported_ops"].size()); i++) + { + if(std::find(CPUSupportOp.begin(), CPUSupportOp.end(), root["CPU_supported_ops"][i].asString())== CPUSupportOp.end()) + { + CPUSupportOp.push_back(root["CPU_supported_ops"][i].asString()); + } + } + } + in.close(); + } + void updateOnnxFile(std::string &path) { + onnxFile = path; + } + + std::string getOnnxFile() { + return onnxFile; + } + +}; + + +#endif diff --git a/tools/onnx-subgraph/include/graph.h b/tools/onnx-subgraph/include/graph.h new file mode 100644 index 00000000000..04880e90ccb --- /dev/null +++ b/tools/onnx-subgraph/include/graph.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GRAPH_H +#define GRAPH_H + +#include "onnx.pb.h" +#include +#include +#include +#include +// save the size of each node's inputs and outputs +struct NodeIOSize { + std::vector> inputSizes; + std::vector> outputSizes; +}; + +struct NodeTensor { + std::string name; + std::vector shape; + + // Default constructor + NodeTensor() = default; + + // Constructor with parameters + NodeTensor(const std::string& n, const std::vector& s) : name(n), shape(s) {} + + // Equality comparison operator + bool operator==(const NodeTensor& other) const { + return name == other.name && shape == other.shape; + } +}; + +namespace std { + template <> + struct hash { + size_t operator()(const NodeTensor& tensor) const { + size_t hashValue = hash()(tensor.name); + for (auto& val : tensor.shape) { + hashValue ^= hash()(val) + 0x9e3779b9 + (hashValue << 6) + (hashValue >> 2); + } + return hashValue; + } + }; +} +/** +* @brief Extracts the names and shapes of initializers from the ONNX graph. +* +* @param [in] graph The ONNX graph from which to extract initializers. +* @pre The ONNX graph should be valid and contain initializers. +* @post The names and shapes of the initializers are stored in an unordered set of NodeTensor objects. +* @exception None +* @return An unordered set of NodeTensor objects containing the names and shapes of the initializers. +*/ +std::unordered_set getInitializer(const onnx::GraphProto& graph); +/** +* @brief Extracts the names and shapes of inputs, outputs, and value_info from the ONNX graph. +* +* @param [in] graph The ONNX graph from which to extract inputs, outputs, and value_info. +* @pre The ONNX graph should be valid and contain inputs, outputs, and value_info. +* @post The names and shapes of the inputs, outputs, and value_info are stored in an unordered set of NodeTensor objects. +* @exception None +* @return An unordered set of NodeTensor objects containing the names and shapes of the inputs, outputs, and value_info. +*/ +std::unordered_set getIOvalue(const onnx::GraphProto& graph); +/** +* @brief Determines the input tensors of the graph that are not produced by any node in the graph. +* +* @param [in] g The ONNX GraphProto object representing the graph. +* @param [in] initializerNames A set of NodeTensor objects representing the initializers in the graph. +* @param [out] graphInputs A set of NodeTensor objects representing the input tensors of the graph. +* @pre The GraphProto object g should be valid and contain nodes with proper input and output lists. +* @post The graphInputs set will be populated with NodeTensor objects that are inputs to the graph. +* @exception None +* @return None +*/ +void determineGraphInput(const onnx::GraphProto& g, const std::unordered_set& initializerNames, + std::unordered_set &graphInputs); +/** +* @brief Determines the output tensors of the graph that are either outputs of the original graph or are used as inputs in other parts of the graph. +* +* @param [in] originalGraph The original ONNX GraphProto object representing the graph. +* @param [in] g The ONNX GraphProto object representing the graph to analyze. +* @param [in] allgraphInputs_1 A vector of sets of NodeTensor objects representing the first set of inputs to the graph. +* @param [in] allgraphInputs_2 A vector of sets of NodeTensor objects representing the second set of inputs to the graph. +* @param [out] graphOutputs A set of NodeTensor objects representing the output tensors of the graph. +* @pre The GraphProto objects originalGraph and g should be valid and contain nodes with proper input and output lists. +* @post The graphOutputs set will be populated with NodeTensor objects that are outputs of the graph. +* @exception None +* @return None +*/ +void determineGraphOutput(const onnx::GraphProto& originalGraph, const onnx::GraphProto& g, std::vector> &allgraphInputs_1, + std::vector> &allgraphInputs_2, std::unordered_set &graphOutputs); +/** +* @brief Finds the name of the node that produces a specified output tensor in the given ONNX graph. +* +* @param [in] g The ONNX GraphProto object representing the graph. +* @param [in] outputTensorName The name of the output tensor to find the producing node for. +* @pre The GraphProto object g should be valid and contain nodes with proper input and output lists. +* @post None +* @exception None +* @return The name of the node that produces the specified output tensor, or an empty string if no such node is found. +*/ +std::string findInputNode(const onnx::GraphProto &g, const std::string& outputTensorName); +/** +* @brief Collects the names of all nodes in the given ONNX graph. +* +* @param [in] graph The ONNX GraphProto object representing the graph. +* @pre The GraphProto object graph should be valid and contain nodes with proper names. +* @post None +* @exception None +* @return An unordered set containing the names of all nodes in the graph. +*/ +std::unordered_set collectNodeNames(const onnx::GraphProto& graph); +/** +* @brief Merges nodes from the source graph into the target graph. +* +* @param [in,out] targetGraph The ONNX GraphProto object to which nodes will be added. +* @param [in] sourceGraph The ONNX GraphProto object from which nodes will be copied. +* @pre Both GraphProto objects should be valid. +* @post Nodes from sourceGraph are added to targetGraph. +* @exception Exits the program with an error message if the number of nodes in targetGraph does not match the expected size after merging. +* @return None +*/ +void mergeGraphs(onnx::GraphProto& targetGraph, onnx::GraphProto& sourceGraph); + +class Graph { +private: + /* data */ +public: + Graph() {} + ~Graph() {} + /** + * @brief Loads an ONNX model from a file and returns the graph contained within. + * + * @param [in] path The file path to the ONNX model. + * @pre The file specified by path should exist and be a valid ONNX model. + * @post The ONNX model is parsed and its graph is returned. + * @exception Exits the program with an error message if the file cannot be opened. + * @return The ONNX GraphProto object representing the graph from the model. + */ + onnx::GraphProto GetGraphFromOnnx(std::string &path); + +}; +struct graph_adjacency_node +{ + std::vector output_node_index; + int rank; + std::string name; + int index; +}; +#endif diff --git a/tools/onnx-subgraph/include/json-forwards.h b/tools/onnx-subgraph/include/json-forwards.h new file mode 100644 index 00000000000..61d08880068 --- /dev/null +++ b/tools/onnx-subgraph/include/json-forwards.h @@ -0,0 +1,265 @@ +/// Json-cpp amalgated forward header (http://jsoncpp.sourceforge.net/). +/// It is intended to be used with #include "json/json-forwards.h" +/// This header provides forward declaration for all JsonCpp types. + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + +/* +The JsonCpp library's source code, including accompanying documentation, +tests and demonstration applications, are licensed under the following +conditions... + +The author (Baptiste Lepilleur) explicitly disclaims copyright in all +jurisdictions which recognize such a disclaimer. In such jurisdictions, +this software is released into the Public Domain. + +In jurisdictions which do not recognize Public Domain property (e.g. Germany as of +2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is +released under the terms of the MIT License (see below). + +In jurisdictions which recognize Public Domain property, the user of this +software may choose to accept it either as 1) Public Domain, 2) under the +conditions of the MIT License (see below), or 3) under the terms of dual +Public Domain/MIT License conditions described here, as they choose. + +The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: + + http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +======================================================================== +Copyright (c) 2007-2010 Baptiste Lepilleur + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +======================================================================== +(END LICENSE TEXT) + +The MIT license is compatible with both the GPL and commercial +software, affording one all of the rights of Public Domain with the +minor nuisance of being required to keep the above copyright notice +and license text in the source code. Note also that by accepting the +Public Domain "license" you can re-license your copy using whatever +license you like. + +*/ + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + + + + + +#ifndef JSON_FORWARD_AMALGATED_H_INCLUDED +# define JSON_FORWARD_AMALGATED_H_INCLUDED +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +#define JSON_IS_AMALGAMATION + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_CONFIG_H_INCLUDED +#define JSON_CONFIG_H_INCLUDED + +/// If defined, indicates that json library is embedded in CppTL library. +//# define JSON_IN_CPPTL 1 + +/// If defined, indicates that json may leverage CppTL library +//# define JSON_USE_CPPTL 1 +/// If defined, indicates that cpptl vector based map should be used instead of +/// std::map +/// as Value container. +//# define JSON_USE_CPPTL_SMALLMAP 1 + +// If non-zero, the library uses exceptions to report bad input instead of C +// assertion macros. The default is to use exceptions. +#ifndef JSON_USE_EXCEPTION +#define JSON_USE_EXCEPTION 1 +#endif + +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +/// Remarks: it is automatically defined in the generated amalgated header. +// #define JSON_IS_AMALGAMATION + +#ifdef JSON_IN_CPPTL +#include +#ifndef JSON_USE_CPPTL +#define JSON_USE_CPPTL 1 +#endif +#endif + +#ifdef JSON_IN_CPPTL +#define JSON_API CPPTL_API +#elif defined(JSON_DLL_BUILD) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllexport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#elif defined(JSON_DLL) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllimport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#endif // ifdef JSON_IN_CPPTL +#if !defined(JSON_API) +#define JSON_API +#endif + +#if !defined(JSON_HAS_UNIQUE_PTR) +#if __cplusplus >= 201103L +#define JSON_HAS_UNIQUE_PTR (1) +#elif _MSC_VER >= 1600 +#define JSON_HAS_UNIQUE_PTR (1) +#else +#define JSON_HAS_UNIQUE_PTR (0) +#endif +#endif + +// If JSON_NO_INT64 is defined, then Json only support C++ "int" type for +// integer +// Storages, and 64 bits integer support is disabled. +// #define JSON_NO_INT64 1 + +#if defined(_MSC_VER) && _MSC_VER <= 1200 // MSVC 6 +// Microsoft Visual Studio 6 only support conversion from __int64 to double +// (no conversion from unsigned __int64). +#define JSON_USE_INT64_DOUBLE_CONVERSION 1 +// Disable warning 4786 for VS6 caused by STL (identifier was truncated to '255' +// characters in the debug information) +// All projects I've ever seen with VS6 were using this globally (not bothering +// with pragma push/pop). +#pragma warning(disable : 4786) +#endif // if defined(_MSC_VER) && _MSC_VER < 1200 // MSVC 6 + +#if defined(_MSC_VER) && _MSC_VER >= 1500 // MSVC 2008 +/// Indicates that the following function is deprecated. +#define JSONCPP_DEPRECATED(message) __declspec(deprecated(message)) +#elif defined(__clang__) && defined(__has_feature) +#if __has_feature(attribute_deprecated_with_message) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#elif defined(__GNUC__) && (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) +#define JSONCPP_DEPRECATED(message) __attribute__((__deprecated__)) +#endif + +#if !defined(JSONCPP_DEPRECATED) +#define JSONCPP_DEPRECATED(message) +#endif // if !defined(JSONCPP_DEPRECATED) + +namespace Json { +typedef int Int; +typedef unsigned int UInt; +#if defined(JSON_NO_INT64) +typedef int LargestInt; +typedef unsigned int LargestUInt; +#undef JSON_HAS_INT64 +#else // if defined(JSON_NO_INT64) +// For Microsoft Visual use specific types as long long is not supported +#if defined(_MSC_VER) // Microsoft Visual Studio +typedef __int64 Int64; +typedef unsigned __int64 UInt64; +#else // if defined(_MSC_VER) // Other platforms, use long long +typedef long long int Int64; +typedef unsigned long long int UInt64; +#endif // if defined(_MSC_VER) +typedef Int64 LargestInt; +typedef UInt64 LargestUInt; +#define JSON_HAS_INT64 +#endif // if defined(JSON_NO_INT64) +} // end namespace Json + +#endif // JSON_CONFIG_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_FORWARDS_H_INCLUDED +#define JSON_FORWARDS_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "config.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +// writer.h +class FastWriter; +class StyledWriter; + +// reader.h +class Reader; + +// features.h +class Features; + +// value.h +typedef unsigned int ArrayIndex; +class StaticString; +class Path; +class PathArgument; +class Value; +class ValueIteratorBase; +class ValueIterator; +class ValueConstIterator; + +} // namespace Json + +#endif // JSON_FORWARDS_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + + + + + +#endif //ifndef JSON_FORWARD_AMALGATED_H_INCLUDED diff --git a/tools/onnx-subgraph/include/json.h b/tools/onnx-subgraph/include/json.h new file mode 100644 index 00000000000..67b523a64b7 --- /dev/null +++ b/tools/onnx-subgraph/include/json.h @@ -0,0 +1,1996 @@ +/// Json-cpp amalgated header (http://jsoncpp.sourceforge.net/). +/// It is intended to be used with #include "json/json.h" + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + +/* +The JsonCpp library's source code, including accompanying documentation, +tests and demonstration applications, are licensed under the following +conditions... + +The author (Baptiste Lepilleur) explicitly disclaims copyright in all +jurisdictions which recognize such a disclaimer. In such jurisdictions, +this software is released into the Public Domain. + +In jurisdictions which do not recognize Public Domain property (e.g. Germany as of +2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is +released under the terms of the MIT License (see below). + +In jurisdictions which recognize Public Domain property, the user of this +software may choose to accept it either as 1) Public Domain, 2) under the +conditions of the MIT License (see below), or 3) under the terms of dual +Public Domain/MIT License conditions described here, as they choose. + +The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: + + http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +======================================================================== +Copyright (c) 2007-2010 Baptiste Lepilleur + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +======================================================================== +(END LICENSE TEXT) + +The MIT license is compatible with both the GPL and commercial +software, affording one all of the rights of Public Domain with the +minor nuisance of being required to keep the above copyright notice +and license text in the source code. Note also that by accepting the +Public Domain "license" you can re-license your copy using whatever +license you like. + +*/ + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + + + + + +#ifndef JSON_AMALGATED_H_INCLUDED +# define JSON_AMALGATED_H_INCLUDED +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +#define JSON_IS_AMALGAMATION + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/version.h +// ////////////////////////////////////////////////////////////////////// + +// DO NOT EDIT. This file (and "version") is generated by CMake. +// Run CMake configure step to update it. +#ifndef JSON_VERSION_H_INCLUDED +# define JSON_VERSION_H_INCLUDED + +# define JSONCPP_VERSION_STRING "0.10.7" +# define JSONCPP_VERSION_MAJOR 0 +# define JSONCPP_VERSION_MINOR 10 +# define JSONCPP_VERSION_PATCH 7 +# define JSONCPP_VERSION_QUALIFIER +# define JSONCPP_VERSION_HEXA ((JSONCPP_VERSION_MAJOR << 24) | (JSONCPP_VERSION_MINOR << 16) | (JSONCPP_VERSION_PATCH << 8)) + +#endif // JSON_VERSION_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/version.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_CONFIG_H_INCLUDED +#define JSON_CONFIG_H_INCLUDED + +/// If defined, indicates that json library is embedded in CppTL library. +//# define JSON_IN_CPPTL 1 + +/// If defined, indicates that json may leverage CppTL library +//# define JSON_USE_CPPTL 1 +/// If defined, indicates that cpptl vector based map should be used instead of +/// std::map +/// as Value container. +//# define JSON_USE_CPPTL_SMALLMAP 1 + +// If non-zero, the library uses exceptions to report bad input instead of C +// assertion macros. The default is to use exceptions. +#ifndef JSON_USE_EXCEPTION +#define JSON_USE_EXCEPTION 1 +#endif + +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +/// Remarks: it is automatically defined in the generated amalgated header. +// #define JSON_IS_AMALGAMATION + +#ifdef JSON_IN_CPPTL +#include +#ifndef JSON_USE_CPPTL +#define JSON_USE_CPPTL 1 +#endif +#endif + +#ifdef JSON_IN_CPPTL +#define JSON_API CPPTL_API +#elif defined(JSON_DLL_BUILD) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllexport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#elif defined(JSON_DLL) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllimport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#endif // ifdef JSON_IN_CPPTL +#if !defined(JSON_API) +#define JSON_API +#endif + +#if !defined(JSON_HAS_UNIQUE_PTR) +#if __cplusplus >= 201103L +#define JSON_HAS_UNIQUE_PTR (1) +#elif _MSC_VER >= 1600 +#define JSON_HAS_UNIQUE_PTR (1) +#else +#define JSON_HAS_UNIQUE_PTR (0) +#endif +#endif + +// If JSON_NO_INT64 is defined, then Json only support C++ "int" type for +// integer +// Storages, and 64 bits integer support is disabled. +// #define JSON_NO_INT64 1 + +#if defined(_MSC_VER) && _MSC_VER <= 1200 // MSVC 6 +// Microsoft Visual Studio 6 only support conversion from __int64 to double +// (no conversion from unsigned __int64). +#define JSON_USE_INT64_DOUBLE_CONVERSION 1 +// Disable warning 4786 for VS6 caused by STL (identifier was truncated to '255' +// characters in the debug information) +// All projects I've ever seen with VS6 were using this globally (not bothering +// with pragma push/pop). +#pragma warning(disable : 4786) +#endif // if defined(_MSC_VER) && _MSC_VER < 1200 // MSVC 6 + +#if defined(_MSC_VER) && _MSC_VER >= 1500 // MSVC 2008 +/// Indicates that the following function is deprecated. +#define JSONCPP_DEPRECATED(message) __declspec(deprecated(message)) +#elif defined(__clang__) && defined(__has_feature) +#if __has_feature(attribute_deprecated_with_message) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#elif defined(__GNUC__) && (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) +#define JSONCPP_DEPRECATED(message) __attribute__((__deprecated__)) +#endif + +#if !defined(JSONCPP_DEPRECATED) +#define JSONCPP_DEPRECATED(message) +#endif // if !defined(JSONCPP_DEPRECATED) + +namespace Json { +typedef int Int; +typedef unsigned int UInt; +#if defined(JSON_NO_INT64) +typedef int LargestInt; +typedef unsigned int LargestUInt; +#undef JSON_HAS_INT64 +#else // if defined(JSON_NO_INT64) +// For Microsoft Visual use specific types as long long is not supported +#if defined(_MSC_VER) // Microsoft Visual Studio +typedef __int64 Int64; +typedef unsigned __int64 UInt64; +#else // if defined(_MSC_VER) // Other platforms, use long long +typedef long long int Int64; +typedef unsigned long long int UInt64; +#endif // if defined(_MSC_VER) +typedef Int64 LargestInt; +typedef UInt64 LargestUInt; +#define JSON_HAS_INT64 +#endif // if defined(JSON_NO_INT64) +} // end namespace Json + +#endif // JSON_CONFIG_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_FORWARDS_H_INCLUDED +#define JSON_FORWARDS_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "config.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +// writer.h +class FastWriter; +class StyledWriter; + +// reader.h +class Reader; + +// features.h +class Features; + +// value.h +typedef unsigned int ArrayIndex; +class StaticString; +class Path; +class PathArgument; +class Value; +class ValueIteratorBase; +class ValueIterator; +class ValueConstIterator; + +} // namespace Json + +#endif // JSON_FORWARDS_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/features.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_FEATURES_H_INCLUDED +#define CPPTL_JSON_FEATURES_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "forwards.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +/** \brief Configuration passed to reader and writer. + * This configuration object can be used to force the Reader or Writer + * to behave in a standard conforming way. + */ +class JSON_API Features { +public: + /** \brief A configuration that allows all features and assumes all strings + * are UTF-8. + * - C & C++ comments are allowed + * - Root object can be any JSON value + * - Assumes Value strings are encoded in UTF-8 + */ + static Features all(); + + /** \brief A configuration that is strictly compatible with the JSON + * specification. + * - Comments are forbidden. + * - Root object must be either an array or an object value. + * - Assumes Value strings are encoded in UTF-8 + */ + static Features strictMode(); + + /** \brief Initialize the configuration like JsonConfig::allFeatures; + */ + Features(); + + /// \c true if comments are allowed. Default: \c true. + bool allowComments_; + + /// \c true if root must be either an array or an object value. Default: \c + /// false. + bool strictRoot_; +}; + +} // namespace Json + +#endif // CPPTL_JSON_FEATURES_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/features.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/value.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_H_INCLUDED +#define CPPTL_JSON_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "forwards.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include + +#ifndef JSON_USE_CPPTL_SMALLMAP +#include +#else +#include +#endif +#ifdef JSON_USE_CPPTL +#include +#endif + +// Disable warning C4251: : needs to have dll-interface to +// be used by... +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(push) +#pragma warning(disable : 4251) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +//Conditional NORETURN attribute on the throw functions would: +// a) suppress false positives from static code analysis +// b) possibly improve optimization opportunities. +#if !defined(JSONCPP_NORETURN) +# if defined(_MSC_VER) +# define JSONCPP_NORETURN __declspec(noreturn) +# elif defined(__GNUC__) +# define JSONCPP_NORETURN __attribute__ ((__noreturn__)) +# else +# define JSONCPP_NORETURN +# endif +#endif + +/** \brief JSON (JavaScript Object Notation). + */ +namespace Json { + +/** Base class for all exceptions we throw. + * + * We use nothing but these internally. Of course, STL can throw others. + */ +class JSON_API Exception : public std::exception { +public: + Exception(std::string const& msg); + virtual ~Exception() throw(); + virtual char const* what() const throw(); +protected: + std::string const msg_; +}; + +/** Exceptions which the user cannot easily avoid. + * + * E.g. out-of-memory (when we use malloc), stack-overflow, malicious input + * + * \remark derived from Json::Exception + */ +class JSON_API RuntimeError : public Exception { +public: + RuntimeError(std::string const& msg); +}; + +/** Exceptions thrown by JSON_ASSERT/JSON_FAIL macros. + * + * These are precondition-violations (user bugs) and internal errors (our bugs). + * + * \remark derived from Json::Exception + */ +class JSON_API LogicError : public Exception { +public: + LogicError(std::string const& msg); +}; + +/// used internally +JSONCPP_NORETURN void throwRuntimeError(std::string const& msg); +/// used internally +JSONCPP_NORETURN void throwLogicError(std::string const& msg); + +/** \brief Type of the value held by a Value object. + */ +enum ValueType { + nullValue = 0, ///< 'null' value + intValue, ///< signed integer value + uintValue, ///< unsigned integer value + realValue, ///< double value + stringValue, ///< UTF-8 string value + booleanValue, ///< bool value + arrayValue, ///< array value (ordered list) + objectValue ///< object value (collection of name/value pairs). +}; + +enum CommentPlacement { + commentBefore = 0, ///< a comment placed on the line before a value + commentAfterOnSameLine, ///< a comment just after a value on the same line + commentAfter, ///< a comment on the line after a value (only make sense for + /// root value) + numberOfCommentPlacement +}; + +//# ifdef JSON_USE_CPPTL +// typedef CppTL::AnyEnumerator EnumMemberNames; +// typedef CppTL::AnyEnumerator EnumValues; +//# endif + +/** \brief Lightweight wrapper to tag static string. + * + * Value constructor and objectValue member assignement takes advantage of the + * StaticString and avoid the cost of string duplication when storing the + * string or the member name. + * + * Example of usage: + * \code + * Json::Value aValue( StaticString("some text") ); + * Json::Value object; + * static const StaticString code("code"); + * object[code] = 1234; + * \endcode + */ +class JSON_API StaticString { +public: + explicit StaticString(const char* czstring) : c_str_(czstring) {} + + operator const char*() const { return c_str_; } + + const char* c_str() const { return c_str_; } + +private: + const char* c_str_; +}; + +/** \brief Represents a JSON value. + * + * This class is a discriminated union wrapper that can represents a: + * - signed integer [range: Value::minInt - Value::maxInt] + * - unsigned integer (range: 0 - Value::maxUInt) + * - double + * - UTF-8 string + * - boolean + * - 'null' + * - an ordered list of Value + * - collection of name/value pairs (javascript object) + * + * The type of the held value is represented by a #ValueType and + * can be obtained using type(). + * + * Values of an #objectValue or #arrayValue can be accessed using operator[]() + * methods. + * Non-const methods will automatically create the a #nullValue element + * if it does not exist. + * The sequence of an #arrayValue will be automatically resized and initialized + * with #nullValue. resize() can be used to enlarge or truncate an #arrayValue. + * + * The get() methods can be used to obtain default value in the case the + * required element does not exist. + * + * It is possible to iterate over the list of a #objectValue values using + * the getMemberNames() method. + * + * \note #Value string-length fit in size_t, but keys must be < 2^30. + * (The reason is an implementation detail.) A #CharReader will raise an + * exception if a bound is exceeded to avoid security holes in your app, + * but the Value API does *not* check bounds. That is the responsibility + * of the caller. + */ +class JSON_API Value { + friend class ValueIteratorBase; +public: + typedef std::vector Members; + typedef ValueIterator iterator; + typedef ValueConstIterator const_iterator; + typedef Json::UInt UInt; + typedef Json::Int Int; +#if defined(JSON_HAS_INT64) + typedef Json::UInt64 UInt64; + typedef Json::Int64 Int64; +#endif // defined(JSON_HAS_INT64) + typedef Json::LargestInt LargestInt; + typedef Json::LargestUInt LargestUInt; + typedef Json::ArrayIndex ArrayIndex; + + static const Value& nullRef; +#if !defined(__ARMEL__) + /// \deprecated This exists for binary compatibility only. Use nullRef. + static const Value null; +#endif + /// Minimum signed integer value that can be stored in a Json::Value. + static const LargestInt minLargestInt; + /// Maximum signed integer value that can be stored in a Json::Value. + static const LargestInt maxLargestInt; + /// Maximum unsigned integer value that can be stored in a Json::Value. + static const LargestUInt maxLargestUInt; + + /// Minimum signed int value that can be stored in a Json::Value. + static const Int minInt; + /// Maximum signed int value that can be stored in a Json::Value. + static const Int maxInt; + /// Maximum unsigned int value that can be stored in a Json::Value. + static const UInt maxUInt; + +#if defined(JSON_HAS_INT64) + /// Minimum signed 64 bits int value that can be stored in a Json::Value. + static const Int64 minInt64; + /// Maximum signed 64 bits int value that can be stored in a Json::Value. + static const Int64 maxInt64; + /// Maximum unsigned 64 bits int value that can be stored in a Json::Value. + static const UInt64 maxUInt64; +#endif // defined(JSON_HAS_INT64) + +//MW: workaround for bug in NVIDIAs CUDA 7.5 nvcc compiler +#ifdef __NVCC__ +public: +#else +private: +#endif //__NVCC__ +#ifndef JSONCPP_DOC_EXCLUDE_IMPLEMENTATION + class CZString { + public: + enum DuplicationPolicy { + noDuplication = 0, + duplicate, + duplicateOnCopy + }; + CZString(ArrayIndex index); + CZString(char const* str, unsigned length, DuplicationPolicy allocate); + CZString(CZString const& other); + ~CZString(); + CZString& operator=(CZString other); + bool operator<(CZString const& other) const; + bool operator==(CZString const& other) const; + ArrayIndex index() const; + //const char* c_str() const; ///< \deprecated + char const* data() const; + unsigned length() const; + bool isStaticString() const; + + private: + void swap(CZString& other); + + struct StringStorage { + unsigned policy_: 2; + unsigned length_: 30; // 1GB max + }; + + char const* cstr_; // actually, a prefixed string, unless policy is noDup + union { + ArrayIndex index_; + StringStorage storage_; + }; + }; + +public: +#ifndef JSON_USE_CPPTL_SMALLMAP + typedef std::map ObjectValues; +#else + typedef CppTL::SmallMap ObjectValues; +#endif // ifndef JSON_USE_CPPTL_SMALLMAP +#endif // ifndef JSONCPP_DOC_EXCLUDE_IMPLEMENTATION + +public: + /** \brief Create a default Value of the given type. + + This is a very useful constructor. + To create an empty array, pass arrayValue. + To create an empty object, pass objectValue. + Another Value can then be set to this one by assignment. +This is useful since clear() and resize() will not alter types. + + Examples: +\code +Json::Value null_value; // null +Json::Value arr_value(Json::arrayValue); // [] +Json::Value obj_value(Json::objectValue); // {} +\endcode + */ + Value(ValueType type = nullValue); + Value(Int value); + Value(UInt value); +#if defined(JSON_HAS_INT64) + Value(Int64 value); + Value(UInt64 value); +#endif // if defined(JSON_HAS_INT64) + Value(double value); + Value(const char* value); ///< Copy til first 0. (NULL causes to seg-fault.) + Value(const char* begin, const char* end); ///< Copy all, incl zeroes. + /** \brief Constructs a value from a static string. + + * Like other value string constructor but do not duplicate the string for + * internal storage. The given string must remain alive after the call to this + * constructor. + * \note This works only for null-terminated strings. (We cannot change the + * size of this class, so we have nowhere to store the length, + * which might be computed later for various operations.) + * + * Example of usage: + * \code + * static StaticString foo("some text"); + * Json::Value aValue(foo); + * \endcode + */ + Value(const StaticString& value); + Value(const std::string& value); ///< Copy data() til size(). Embedded zeroes too. +#ifdef JSON_USE_CPPTL + Value(const CppTL::ConstString& value); +#endif + Value(bool value); + /// Deep copy. + Value(const Value& other); + ~Value(); + + /// Deep copy, then swap(other). + /// \note Over-write existing comments. To preserve comments, use #swapPayload(). + Value &operator=(const Value &other); + /// Swap everything. + void swap(Value& other); + /// Swap values but leave comments and source offsets in place. + void swapPayload(Value& other); + + ValueType type() const; + + /// Compare payload only, not comments etc. + bool operator<(const Value& other) const; + bool operator<=(const Value& other) const; + bool operator>=(const Value& other) const; + bool operator>(const Value& other) const; + bool operator==(const Value& other) const; + bool operator!=(const Value& other) const; + int compare(const Value& other) const; + + const char* asCString() const; ///< Embedded zeroes could cause you trouble! + std::string asString() const; ///< Embedded zeroes are possible. + /** Get raw char* of string-value. + * \return false if !string. (Seg-fault if str or end are NULL.) + */ + bool getString( + char const** begin, char const** end) const; +#ifdef JSON_USE_CPPTL + CppTL::ConstString asConstString() const; +#endif + Int asInt() const; + UInt asUInt() const; +#if defined(JSON_HAS_INT64) + Int64 asInt64() const; + UInt64 asUInt64() const; +#endif // if defined(JSON_HAS_INT64) + LargestInt asLargestInt() const; + LargestUInt asLargestUInt() const; + float asFloat() const; + double asDouble() const; + bool asBool() const; + + bool isNull() const; + bool isBool() const; + bool isInt() const; + bool isInt64() const; + bool isUInt() const; + bool isUInt64() const; + bool isIntegral() const; + bool isDouble() const; + bool isNumeric() const; + bool isString() const; + bool isArray() const; + bool isObject() const; + + bool isConvertibleTo(ValueType other) const; + + /// Number of values in array or object + ArrayIndex size() const; + + /// \brief Return true if empty array, empty object, or null; + /// otherwise, false. + bool empty() const; + + /// Return isNull() + bool operator!() const; + + /// Remove all object members and array elements. + /// \pre type() is arrayValue, objectValue, or nullValue + /// \post type() is unchanged + void clear(); + + /// Resize the array to size elements. + /// New elements are initialized to null. + /// May only be called on nullValue or arrayValue. + /// \pre type() is arrayValue or nullValue + /// \post type() is arrayValue + void resize(ArrayIndex size); + + /// Access an array element (zero based index ). + /// If the array contains less than index element, then null value are + /// inserted + /// in the array so that its size is index+1. + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + Value& operator[](ArrayIndex index); + + /// Access an array element (zero based index ). + /// If the array contains less than index element, then null value are + /// inserted + /// in the array so that its size is index+1. + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + Value& operator[](int index); + + /// Access an array element (zero based index ) + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + const Value& operator[](ArrayIndex index) const; + + /// Access an array element (zero based index ) + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + const Value& operator[](int index) const; + + /// If the array contains at least index+1 elements, returns the element + /// value, + /// otherwise returns defaultValue. + Value get(ArrayIndex index, const Value& defaultValue) const; + /// Return true if index < size(). + bool isValidIndex(ArrayIndex index) const; + /// \brief Append value to array at the end. + /// + /// Equivalent to jsonvalue[jsonvalue.size()] = value; + Value& append(const Value& value); + + /// Access an object value by name, create a null member if it does not exist. + /// \note Because of our implementation, keys are limited to 2^30 -1 chars. + /// Exceeding that will cause an exception. + Value& operator[](const char* key); + /// Access an object value by name, returns null if there is no member with + /// that name. + const Value& operator[](const char* key) const; + /// Access an object value by name, create a null member if it does not exist. + /// \param key may contain embedded nulls. + Value& operator[](const std::string& key); + /// Access an object value by name, returns null if there is no member with + /// that name. + /// \param key may contain embedded nulls. + const Value& operator[](const std::string& key) const; + /** \brief Access an object value by name, create a null member if it does not + exist. + + * If the object has no entry for that name, then the member name used to store + * the new entry is not duplicated. + * Example of use: + * \code + * Json::Value object; + * static const StaticString code("code"); + * object[code] = 1234; + * \endcode + */ + Value& operator[](const StaticString& key); +#ifdef JSON_USE_CPPTL + /// Access an object value by name, create a null member if it does not exist. + Value& operator[](const CppTL::ConstString& key); + /// Access an object value by name, returns null if there is no member with + /// that name. + const Value& operator[](const CppTL::ConstString& key) const; +#endif + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + Value get(const char* key, const Value& defaultValue) const; + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + /// \note key may contain embedded nulls. + Value get(const char* begin, const char* end, const Value& defaultValue) const; + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + /// \param key may contain embedded nulls. + Value get(const std::string& key, const Value& defaultValue) const; +#ifdef JSON_USE_CPPTL + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + Value get(const CppTL::ConstString& key, const Value& defaultValue) const; +#endif + /// Most general and efficient version of isMember()const, get()const, + /// and operator[]const + /// \note As stated elsewhere, behavior is undefined if (end-begin) >= 2^30 + Value const* find(char const* begin, char const* end) const; + /// Most general and efficient version of object-mutators. + /// \note As stated elsewhere, behavior is undefined if (end-begin) >= 2^30 + /// \return non-zero, but JSON_ASSERT if this is neither object nor nullValue. + Value const* demand(char const* begin, char const* end); + /// \brief Remove and return the named member. + /// + /// Do nothing if it did not exist. + /// \return the removed Value, or null. + /// \pre type() is objectValue or nullValue + /// \post type() is unchanged + /// \deprecated + Value removeMember(const char* key); + /// Same as removeMember(const char*) + /// \param key may contain embedded nulls. + /// \deprecated + Value removeMember(const std::string& key); + /// Same as removeMember(const char* begin, const char* end, Value* removed), + /// but 'key' is null-terminated. + bool removeMember(const char* key, Value* removed); + /** \brief Remove the named map member. + + Update 'removed' iff removed. + \param key may contain embedded nulls. + \return true iff removed (no exceptions) + */ + bool removeMember(std::string const& key, Value* removed); + /// Same as removeMember(std::string const& key, Value* removed) + bool removeMember(const char* begin, const char* end, Value* removed); + /** \brief Remove the indexed array element. + + O(n) expensive operations. + Update 'removed' iff removed. + \return true iff removed (no exceptions) + */ + bool removeIndex(ArrayIndex i, Value* removed); + + /// Return true if the object has a member named key. + /// \note 'key' must be null-terminated. + bool isMember(const char* key) const; + /// Return true if the object has a member named key. + /// \param key may contain embedded nulls. + bool isMember(const std::string& key) const; + /// Same as isMember(std::string const& key)const + bool isMember(const char* begin, const char* end) const; +#ifdef JSON_USE_CPPTL + /// Return true if the object has a member named key. + bool isMember(const CppTL::ConstString& key) const; +#endif + + /// \brief Return a list of the member names. + /// + /// If null, return an empty list. + /// \pre type() is objectValue or nullValue + /// \post if type() was nullValue, it remains nullValue + Members getMemberNames() const; + + //# ifdef JSON_USE_CPPTL + // EnumMemberNames enumMemberNames() const; + // EnumValues enumValues() const; + //# endif + + /// \deprecated Always pass len. + JSONCPP_DEPRECATED("Use setComment(std::string const&) instead.") + void setComment(const char* comment, CommentPlacement placement); + /// Comments must be //... or /* ... */ + void setComment(const char* comment, size_t len, CommentPlacement placement); + /// Comments must be //... or /* ... */ + void setComment(const std::string& comment, CommentPlacement placement); + bool hasComment(CommentPlacement placement) const; + /// Include delimiters and embedded newlines. + std::string getComment(CommentPlacement placement) const; + + std::string toStyledString() const; + + const_iterator begin() const; + const_iterator end() const; + + iterator begin(); + iterator end(); + +private: + void initBasic(ValueType type, bool allocated = false); + + Value& resolveReference(const char* key); + Value& resolveReference(const char* key, const char* end); + + struct CommentInfo { + CommentInfo(); + ~CommentInfo(); + + void setComment(const char* text, size_t len); + + char* comment_; + }; + + // struct MemberNamesTransform + //{ + // typedef const char *result_type; + // const char *operator()( const CZString &name ) const + // { + // return name.c_str(); + // } + //}; + + union ValueHolder { + LargestInt int_; + LargestUInt uint_; + double real_; + bool bool_; + char* string_; // actually ptr to unsigned, followed by str, unless !allocated_ + ObjectValues* map_; + } value_; + ValueType type_ : 8; + unsigned int allocated_ : 1; // Notes: if declared as bool, bitfield is useless. + // If not allocated_, string_ must be null-terminated. + CommentInfo* comments_; +}; + +/** \brief Experimental and untested: represents an element of the "path" to + * access a node. + */ +class JSON_API PathArgument { +public: + friend class Path; + + PathArgument(); + PathArgument(ArrayIndex index); + PathArgument(const char* key); + PathArgument(const std::string& key); + +private: + enum Kind { + kindNone = 0, + kindIndex, + kindKey + }; + std::string key_; + ArrayIndex index_; + Kind kind_; +}; + +/** \brief Experimental and untested: represents a "path" to access a node. + * + * Syntax: + * - "." => root node + * - ".[n]" => elements at index 'n' of root node (an array value) + * - ".name" => member named 'name' of root node (an object value) + * - ".name1.name2.name3" + * - ".[0][1][2].name1[3]" + * - ".%" => member name is provided as parameter + * - ".[%]" => index is provied as parameter + */ +class JSON_API Path { +public: + Path(const std::string& path, + const PathArgument& a1 = PathArgument(), + const PathArgument& a2 = PathArgument(), + const PathArgument& a3 = PathArgument(), + const PathArgument& a4 = PathArgument(), + const PathArgument& a5 = PathArgument()); + + const Value& resolve(const Value& root) const; + Value resolve(const Value& root, const Value& defaultValue) const; + /// Creates the "path" to access the specified node and returns a reference on + /// the node. + Value& make(Value& root) const; + +private: + typedef std::vector InArgs; + typedef std::vector Args; + + void makePath(const std::string& path, const InArgs& in); + void addPathInArg(const std::string& path, + const InArgs& in, + InArgs::const_iterator& itInArg, + PathArgument::Kind kind); + void invalidPath(const std::string& path, int location); + + Args args_; +}; + +/** \brief base class for Value iterators. + * + */ +class JSON_API ValueIteratorBase { +public: + typedef std::bidirectional_iterator_tag iterator_category; + typedef unsigned int size_t; + typedef int difference_type; + typedef ValueIteratorBase SelfType; + + bool operator==(const SelfType& other) const { return isEqual(other); } + + bool operator!=(const SelfType& other) const { return !isEqual(other); } + + difference_type operator-(const SelfType& other) const { + return other.computeDistance(*this); + } + + /// Return either the index or the member name of the referenced value as a + /// Value. + Value key() const; + + /// Return the index of the referenced Value, or -1 if it is not an arrayValue. + UInt index() const; + + /// Return the member name of the referenced Value, or "" if it is not an + /// objectValue. + /// \note Avoid `c_str()` on result, as embedded zeroes are possible. + std::string name() const; + + /// Return the member name of the referenced Value. "" if it is not an + /// objectValue. + /// \deprecated This cannot be used for UTF-8 strings, since there can be embedded nulls. + JSONCPP_DEPRECATED("Use `key = name();` instead.") + char const* memberName() const; + /// Return the member name of the referenced Value, or NULL if it is not an + /// objectValue. + /// \note Better version than memberName(). Allows embedded nulls. + char const* memberName(char const** end) const; + +protected: + Value& deref() const; + + void increment(); + + void decrement(); + + difference_type computeDistance(const SelfType& other) const; + + bool isEqual(const SelfType& other) const; + + void copy(const SelfType& other); + +private: + Value::ObjectValues::iterator current_; + // Indicates that iterator is for a null value. + bool isNull_; + +public: + // For some reason, BORLAND needs these at the end, rather + // than earlier. No idea why. + ValueIteratorBase(); + explicit ValueIteratorBase(const Value::ObjectValues::iterator& current); +}; + +/** \brief const iterator for object and array value. + * + */ +class JSON_API ValueConstIterator : public ValueIteratorBase { + friend class Value; + +public: + typedef const Value value_type; + //typedef unsigned int size_t; + //typedef int difference_type; + typedef const Value& reference; + typedef const Value* pointer; + typedef ValueConstIterator SelfType; + + ValueConstIterator(); + +private: +/*! \internal Use by Value to create an iterator. + */ + explicit ValueConstIterator(const Value::ObjectValues::iterator& current); +public: + SelfType& operator=(const ValueIteratorBase& other); + + SelfType operator++(int) { + SelfType temp(*this); + ++*this; + return temp; + } + + SelfType operator--(int) { + SelfType temp(*this); + --*this; + return temp; + } + + SelfType& operator--() { + decrement(); + return *this; + } + + SelfType& operator++() { + increment(); + return *this; + } + + reference operator*() const { return deref(); } + + pointer operator->() const { return &deref(); } +}; + +/** \brief Iterator for object and array value. + */ +class JSON_API ValueIterator : public ValueIteratorBase { + friend class Value; + +public: + typedef Value value_type; + typedef unsigned int size_t; + typedef int difference_type; + typedef Value& reference; + typedef Value* pointer; + typedef ValueIterator SelfType; + + ValueIterator(); + ValueIterator(const ValueConstIterator& other); + ValueIterator(const ValueIterator& other); + +private: +/*! \internal Use by Value to create an iterator. + */ + explicit ValueIterator(const Value::ObjectValues::iterator& current); +public: + SelfType& operator=(const SelfType& other); + + SelfType operator++(int) { + SelfType temp(*this); + ++*this; + return temp; + } + + SelfType operator--(int) { + SelfType temp(*this); + --*this; + return temp; + } + + SelfType& operator--() { + decrement(); + return *this; + } + + SelfType& operator++() { + increment(); + return *this; + } + + reference operator*() const { return deref(); } + + pointer operator->() const { return &deref(); } +}; + +} // namespace Json + + +namespace std { +/// Specialize std::swap() for Json::Value. +template<> +inline void swap(Json::Value& a, Json::Value& b) { a.swap(b); } +} + + +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(pop) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +#endif // CPPTL_JSON_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/value.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/reader.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_READER_H_INCLUDED +#define CPPTL_JSON_READER_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "features.h" +#include "value.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include +#include +#include + +// Disable warning C4251: : needs to have dll-interface to +// be used by... +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(push) +#pragma warning(disable : 4251) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +namespace Json { + +/** \brief Unserialize a JSON document into a + *Value. + * + * \deprecated Use CharReader and CharReaderBuilder. + */ +class JSON_API Reader { +public: + typedef char Char; + typedef const Char* Location; + + /** \brief Constructs a Reader allowing all features + * for parsing. + */ + Reader(); + + /** \brief Constructs a Reader allowing the specified feature set + * for parsing. + */ + Reader(const Features& features); + + /** \brief Read a Value from a JSON + * document. + * \param document UTF-8 encoded string containing the document to read. + * \param root [out] Contains the root value of the document if it was + * successfully parsed. + * \param collectComments \c true to collect comment and allow writing them + * back during + * serialization, \c false to discard comments. + * This parameter is ignored if + * Features::allowComments_ + * is \c false. + * \return \c true if the document was successfully parsed, \c false if an + * error occurred. + */ + bool + parse(const std::string& document, Value& root, bool collectComments = true); + + /** \brief Read a Value from a JSON + document. + * \param beginDoc Pointer on the beginning of the UTF-8 encoded string of the + document to read. + * \param endDoc Pointer on the end of the UTF-8 encoded string of the + document to read. + * Must be >= beginDoc. + * \param root [out] Contains the root value of the document if it was + * successfully parsed. + * \param collectComments \c true to collect comment and allow writing them + back during + * serialization, \c false to discard comments. + * This parameter is ignored if + Features::allowComments_ + * is \c false. + * \return \c true if the document was successfully parsed, \c false if an + error occurred. + */ + bool parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments = true); + + /// \brief Parse from input stream. + /// \see Json::operator>>(std::istream&, Json::Value&). + bool parse(std::istream& is, Value& root, bool collectComments = true); + + /** \brief Returns a user friendly string that list errors in the parsed + * document. + * \return Formatted error message with the list of errors with their location + * in + * the parsed document. An empty string is returned if no error + * occurred + * during parsing. + * \deprecated Use getFormattedErrorMessages() instead (typo fix). + */ + JSONCPP_DEPRECATED("Use getFormattedErrorMessages() instead.") + std::string getFormatedErrorMessages() const; + + /** \brief Returns a user friendly string that list errors in the parsed + * document. + * \return Formatted error message with the list of errors with their location + * in + * the parsed document. An empty string is returned if no error + * occurred + * during parsing. + */ + std::string getFormattedErrorMessages() const; + +private: + enum TokenType { + tokenEndOfStream = 0, + tokenObjectBegin, + tokenObjectEnd, + tokenArrayBegin, + tokenArrayEnd, + tokenString, + tokenNumber, + tokenTrue, + tokenFalse, + tokenNull, + tokenArraySeparator, + tokenMemberSeparator, + tokenComment, + tokenError + }; + + class Token { + public: + TokenType type_; + Location start_; + Location end_; + }; + + class ErrorInfo { + public: + Token token_; + std::string message_; + Location extra_; + }; + + typedef std::deque Errors; + + bool readToken(Token& token); + void skipSpaces(); + bool match(Location pattern, int patternLength); + bool readComment(); + bool readCStyleComment(); + bool readCppStyleComment(); + bool readString(); + void readNumber(); + bool readValue(); + bool readObject(Token& token); + bool readArray(Token& token); + bool decodeNumber(Token& token); + bool decodeNumber(Token& token, Value& decoded); + bool decodeString(Token& token); + bool decodeString(Token& token, std::string& decoded); + bool decodeDouble(Token& token); + bool decodeDouble(Token& token, Value& decoded); + bool decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool addError(const std::string& message, Token& token, Location extra = 0); + bool recoverFromError(TokenType skipUntilToken); + bool addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken); + void skipUntilSpace(); + Value& currentValue(); + Char getNextChar(); + void + getLocationLineAndColumn(Location location, int& line, int& column) const; + std::string getLocationLineAndColumn(Location location) const; + void addComment(Location begin, Location end, CommentPlacement placement); + void skipCommentTokens(Token& token); + + typedef std::stack Nodes; + Nodes nodes_; + Errors errors_; + std::string document_; + Location begin_; + Location end_; + Location current_; + Location lastValueEnd_; + Value* lastValue_; + std::string commentsBefore_; + Features features_; + bool collectComments_; +}; // Reader + +/** Interface for reading JSON from a char array. + */ +class JSON_API CharReader { +public: + virtual ~CharReader() {} + /** \brief Read a Value from a JSON + document. + * The document must be a UTF-8 encoded string containing the document to read. + * + * \param beginDoc Pointer on the beginning of the UTF-8 encoded string of the + document to read. + * \param endDoc Pointer on the end of the UTF-8 encoded string of the + document to read. + * Must be >= beginDoc. + * \param root [out] Contains the root value of the document if it was + * successfully parsed. + * \param errs [out] Formatted error messages (if not NULL) + * a user friendly string that lists errors in the parsed + * document. + * \return \c true if the document was successfully parsed, \c false if an + error occurred. + */ + virtual bool parse( + char const* beginDoc, char const* endDoc, + Value* root, std::string* errs) = 0; + + class Factory { + public: + virtual ~Factory() {} + /** \brief Allocate a CharReader via operator new(). + * \throw std::exception if something goes wrong (e.g. invalid settings) + */ + virtual CharReader* newCharReader() const = 0; + }; // Factory +}; // CharReader + +/** \brief Build a CharReader implementation. + +Usage: +\code + using namespace Json; + CharReaderBuilder builder; + builder["collectComments"] = false; + Value value; + std::string errs; + bool ok = parseFromStream(builder, std::cin, &value, &errs); +\endcode +*/ +class JSON_API CharReaderBuilder : public CharReader::Factory { +public: + // Note: We use a Json::Value so that we can add data-members to this class + // without a major version bump. + /** Configuration of this builder. + These are case-sensitive. + Available settings (case-sensitive): + - `"collectComments": false or true` + - true to collect comment and allow writing them + back during serialization, false to discard comments. + This parameter is ignored if allowComments is false. + - `"allowComments": false or true` + - true if comments are allowed. + - `"strictRoot": false or true` + - true if root must be either an array or an object value + - `"allowDroppedNullPlaceholders": false or true` + - true if dropped null placeholders are allowed. (See StreamWriterBuilder.) + - `"allowNumericKeys": false or true` + - true if numeric object keys are allowed. + - `"allowSingleQuotes": false or true` + - true if '' are allowed for strings (both keys and values) + - `"stackLimit": integer` + - Exceeding stackLimit (recursive depth of `readValue()`) will + cause an exception. + - This is a security issue (seg-faults caused by deeply nested JSON), + so the default is low. + - `"failIfExtra": false or true` + - If true, `parse()` returns false when extra non-whitespace trails + the JSON value in the input string. + - `"rejectDupKeys": false or true` + - If true, `parse()` returns false when a key is duplicated within an object. + - `"allowSpecialFloats": false or true` + - If true, special float values (NaNs and infinities) are allowed + and their values are lossfree restorable. + + You can examine 'settings_` yourself + to see the defaults. You can also write and read them just like any + JSON Value. + \sa setDefaults() + */ + Json::Value settings_; + + CharReaderBuilder(); + virtual ~CharReaderBuilder(); + + virtual CharReader* newCharReader() const; + + /** \return true if 'settings' are legal and consistent; + * otherwise, indicate bad settings via 'invalid'. + */ + bool validate(Json::Value* invalid) const; + + /** A simple way to update a specific setting. + */ + Value& operator[](std::string key); + + /** Called by ctor, but you can use this to reset settings_. + * \pre 'settings' != NULL (but Json::null is fine) + * \remark Defaults: + * \snippet src/lib_json/json_reader.cpp CharReaderBuilderDefaults + */ + static void setDefaults(Json::Value* settings); + /** Same as old Features::strictMode(). + * \pre 'settings' != NULL (but Json::null is fine) + * \remark Defaults: + * \snippet src/lib_json/json_reader.cpp CharReaderBuilderStrictMode + */ + static void strictMode(Json::Value* settings); +}; + +/** Consume entire stream and use its begin/end. + * Someday we might have a real StreamReader, but for now this + * is convenient. + */ +bool JSON_API parseFromStream( + CharReader::Factory const&, + std::istream&, + Value* root, std::string* errs); + +/** \brief Read from 'sin' into 'root'. + + Always keep comments from the input JSON. + + This can be used to read a file into a particular sub-object. + For example: + \code + Json::Value root; + cin >> root["dir"]["file"]; + cout << root; + \endcode + Result: + \verbatim + { + "dir": { + "file": { + // The input stream JSON would be nested here. + } + } + } + \endverbatim + \throw std::exception on parse error. + \see Json::operator<<() +*/ +JSON_API std::istream& operator>>(std::istream&, Value&); + +} // namespace Json + +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(pop) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +#endif // CPPTL_JSON_READER_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/reader.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/writer.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_WRITER_H_INCLUDED +#define JSON_WRITER_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "value.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include + +// Disable warning C4251: : needs to have dll-interface to +// be used by... +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(push) +#pragma warning(disable : 4251) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +namespace Json { + +class Value; + +/** + +Usage: +\code + using namespace Json; + void writeToStdout(StreamWriter::Factory const& factory, Value const& value) { + std::unique_ptr const writer( + factory.newStreamWriter()); + writer->write(value, &std::cout); + std::cout << std::endl; // add lf and flush + } +\endcode +*/ +class JSON_API StreamWriter { +protected: + std::ostream* sout_; // not owned; will not delete +public: + StreamWriter(); + virtual ~StreamWriter(); + /** Write Value into document as configured in sub-class. + Do not take ownership of sout, but maintain a reference during function. + \pre sout != NULL + \return zero on success (For now, we always return zero, so check the stream instead.) + \throw std::exception possibly, depending on configuration + */ + virtual int write(Value const& root, std::ostream* sout) = 0; + + /** \brief A simple abstract factory. + */ + class JSON_API Factory { + public: + virtual ~Factory(); + /** \brief Allocate a CharReader via operator new(). + * \throw std::exception if something goes wrong (e.g. invalid settings) + */ + virtual StreamWriter* newStreamWriter() const = 0; + }; // Factory +}; // StreamWriter + +/** \brief Write into stringstream, then return string, for convenience. + * A StreamWriter will be created from the factory, used, and then deleted. + */ +std::string JSON_API writeString(StreamWriter::Factory const& factory, Value const& root); + + +/** \brief Build a StreamWriter implementation. + +Usage: +\code + using namespace Json; + Value value = ...; + StreamWriterBuilder builder; + builder["commentStyle"] = "None"; + builder["indentation"] = " "; // or whatever you like + std::unique_ptr writer( + builder.newStreamWriter()); + writer->write(value, &std::cout); + std::cout << std::endl; // add lf and flush +\endcode +*/ +class JSON_API StreamWriterBuilder : public StreamWriter::Factory { +public: + // Note: We use a Json::Value so that we can add data-members to this class + // without a major version bump. + /** Configuration of this builder. + Available settings (case-sensitive): + - "commentStyle": "None" or "All" + - "indentation": "" + - "enableYAMLCompatibility": false or true + - slightly change the whitespace around colons + - "dropNullPlaceholders": false or true + - Drop the "null" string from the writer's output for nullValues. + Strictly speaking, this is not valid JSON. But when the output is being + fed to a browser's Javascript, it makes for smaller output and the + browser can handle the output just fine. + - "useSpecialFloats": false or true + - If true, outputs non-finite floating point values in the following way: + NaN values as "NaN", positive infinity as "Infinity", and negative infinity + as "-Infinity". + + You can examine 'settings_` yourself + to see the defaults. You can also write and read them just like any + JSON Value. + \sa setDefaults() + */ + Json::Value settings_; + + StreamWriterBuilder(); + virtual ~StreamWriterBuilder(); + + /** + * \throw std::exception if something goes wrong (e.g. invalid settings) + */ + virtual StreamWriter* newStreamWriter() const; + + /** \return true if 'settings' are legal and consistent; + * otherwise, indicate bad settings via 'invalid'. + */ + bool validate(Json::Value* invalid) const; + /** A simple way to update a specific setting. + */ + Value& operator[](std::string key); + + /** Called by ctor, but you can use this to reset settings_. + * \pre 'settings' != NULL (but Json::null is fine) + * \remark Defaults: + * \snippet src/lib_json/json_writer.cpp StreamWriterBuilderDefaults + */ + static void setDefaults(Json::Value* settings); +}; + +/** \brief Abstract class for writers. + * \deprecated Use StreamWriter. (And really, this is an implementation detail.) + */ +class JSON_API Writer { +public: + virtual ~Writer(); + + virtual std::string write(const Value& root) = 0; +}; + +/** \brief Outputs a Value in JSON format + *without formatting (not human friendly). + * + * The JSON document is written in a single line. It is not intended for 'human' + *consumption, + * but may be usefull to support feature such as RPC where bandwith is limited. + * \sa Reader, Value + * \deprecated Use StreamWriterBuilder. + */ +class JSON_API FastWriter : public Writer { + +public: + FastWriter(); + virtual ~FastWriter() {} + + void enableYAMLCompatibility(); + +public: // overridden from Writer + virtual std::string write(const Value& root); + +private: + void writeValue(const Value& value); + + std::string document_; + bool yamlCompatiblityEnabled_; +}; + +/** \brief Writes a Value in JSON format in a + *human friendly way. + * + * The rules for line break and indent are as follow: + * - Object value: + * - if empty then print {} without indent and line break + * - if not empty the print '{', line break & indent, print one value per + *line + * and then unindent and line break and print '}'. + * - Array value: + * - if empty then print [] without indent and line break + * - if the array contains no object value, empty array or some other value + *types, + * and all the values fit on one lines, then print the array on a single + *line. + * - otherwise, it the values do not fit on one line, or the array contains + * object or non empty array, then print one value per line. + * + * If the Value have comments then they are outputed according to their + *#CommentPlacement. + * + * \sa Reader, Value, Value::setComment() + * \deprecated Use StreamWriterBuilder. + */ +class JSON_API StyledWriter : public Writer { +public: + StyledWriter(); + virtual ~StyledWriter() {} + +public: // overridden from Writer + /** \brief Serialize a Value in JSON format. + * \param root Value to serialize. + * \return String containing the JSON document that represents the root value. + */ + virtual std::string write(const Value& root); + +private: + void writeValue(const Value& value); + void writeArrayValue(const Value& value); + bool isMultineArray(const Value& value); + void pushValue(const std::string& value); + void writeIndent(); + void writeWithIndent(const std::string& value); + void indent(); + void unindent(); + void writeCommentBeforeValue(const Value& root); + void writeCommentAfterValueOnSameLine(const Value& root); + bool hasCommentForValue(const Value& value); + static std::string normalizeEOL(const std::string& text); + + typedef std::vector ChildValues; + + ChildValues childValues_; + std::string document_; + std::string indentString_; + int rightMargin_; + int indentSize_; + bool addChildValues_; +}; + +/** \brief Writes a Value in JSON format in a + human friendly way, + to a stream rather than to a string. + * + * The rules for line break and indent are as follow: + * - Object value: + * - if empty then print {} without indent and line break + * - if not empty the print '{', line break & indent, print one value per + line + * and then unindent and line break and print '}'. + * - Array value: + * - if empty then print [] without indent and line break + * - if the array contains no object value, empty array or some other value + types, + * and all the values fit on one lines, then print the array on a single + line. + * - otherwise, it the values do not fit on one line, or the array contains + * object or non empty array, then print one value per line. + * + * If the Value have comments then they are outputed according to their + #CommentPlacement. + * + * \param indentation Each level will be indented by this amount extra. + * \sa Reader, Value, Value::setComment() + * \deprecated Use StreamWriterBuilder. + */ +class JSON_API StyledStreamWriter { +public: + StyledStreamWriter(std::string indentation = "\t"); + ~StyledStreamWriter() {} + +public: + /** \brief Serialize a Value in JSON format. + * \param out Stream to write to. (Can be ostringstream, e.g.) + * \param root Value to serialize. + * \note There is no point in deriving from Writer, since write() should not + * return a value. + */ + void write(std::ostream& out, const Value& root); + +private: + void writeValue(const Value& value); + void writeArrayValue(const Value& value); + bool isMultineArray(const Value& value); + void pushValue(const std::string& value); + void writeIndent(); + void writeWithIndent(const std::string& value); + void indent(); + void unindent(); + void writeCommentBeforeValue(const Value& root); + void writeCommentAfterValueOnSameLine(const Value& root); + bool hasCommentForValue(const Value& value); + static std::string normalizeEOL(const std::string& text); + + typedef std::vector ChildValues; + + ChildValues childValues_; + std::ostream* document_; + std::string indentString_; + int rightMargin_; + std::string indentation_; + bool addChildValues_ : 1; + bool indented_ : 1; +}; + +#if defined(JSON_HAS_INT64) +std::string JSON_API valueToString(Int value); +std::string JSON_API valueToString(UInt value); +#endif // if defined(JSON_HAS_INT64) +std::string JSON_API valueToString(LargestInt value); +std::string JSON_API valueToString(LargestUInt value); +std::string JSON_API valueToString(double value); +std::string JSON_API valueToString(bool value); +std::string JSON_API valueToQuotedString(const char* value); + +/// \brief Output using the StyledStreamWriter. +/// \see Json::operator>>() +JSON_API std::ostream& operator<<(std::ostream&, const Value& root); + +} // namespace Json + +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(pop) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +#endif // JSON_WRITER_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/writer.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/assertions.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_ASSERTIONS_H_INCLUDED +#define CPPTL_JSON_ASSERTIONS_H_INCLUDED + +#include +#include + +#if !defined(JSON_IS_AMALGAMATION) +#include "config.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +/** It should not be possible for a maliciously designed file to + * cause an abort() or seg-fault, so these macros are used only + * for pre-condition violations and internal logic errors. + */ +#if JSON_USE_EXCEPTION + +// @todo <= add detail about condition in exception +# define JSON_ASSERT(condition) \ + {if (!(condition)) {Json::throwLogicError( "assert json failed" );}} + +# define JSON_FAIL_MESSAGE(message) \ + { \ + std::ostringstream oss; oss << message; \ + Json::throwLogicError(oss.str()); \ + abort(); \ + } + +#else // JSON_USE_EXCEPTION + +# define JSON_ASSERT(condition) assert(condition) + +// The call to assert() will show the failure message in debug builds. In +// release builds we abort, for a core-dump or debugger. +# define JSON_FAIL_MESSAGE(message) \ + { \ + std::ostringstream oss; oss << message; \ + assert(false && oss.str().c_str()); \ + abort(); \ + } + + +#endif + +#define JSON_ASSERT_MESSAGE(condition, message) \ + if (!(condition)) { \ + JSON_FAIL_MESSAGE(message); \ + } + +#endif // CPPTL_JSON_ASSERTIONS_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/assertions.h +// ////////////////////////////////////////////////////////////////////// + + + + + +#endif //ifndef JSON_AMALGATED_H_INCLUDED diff --git a/tools/onnx-subgraph/include/partition.h b/tools/onnx-subgraph/include/partition.h new file mode 100644 index 00000000000..77441f16fc6 --- /dev/null +++ b/tools/onnx-subgraph/include/partition.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARTITION_H +#define PARTITION_H + +#include "onnx.pb.h" +#include +#include +#include +#include +#include +#include "device.h" +#include "graph.h" + +//deprecated +enum PartitionStrategy { + SPILTE_CPU_STRUCTURE_FIRST, + SPILTE_NPU_STRUCTURE_FIRST, + AUTOMATIC_SEARCH +}; + +class Partition { +private: + /* data */ +public: + Partition() {} + ~Partition() {} + /** + * @brief Partition the ONNX graph into subgraphs and produce cutting instructions. + * + * @param [in] g The ONNX graph to be partitioned. + * @param [in] d The device information for partitioning. + * @param [in] strategy The partition strategy to be used (deprecated). + * @param [in] node_io_size The input/output size information for each node. + * @pre The ONNX graph should be valid and the device information should be properly set. + * @post The graph is partitioned into subgraphs, and the results are stored in Subgraphs and otherSubgraphs. + * @exception None + * @return None + */ + void PartitionGraph(const onnx::GraphProto &g, Device& d, PartitionStrategy strategy, const std::unordered_map &node_io_size); +}; +#endif diff --git a/tools/onnx-subgraph/model_inference.py b/tools/onnx-subgraph/model_inference.py new file mode 100644 index 00000000000..e30a158107f --- /dev/null +++ b/tools/onnx-subgraph/model_inference.py @@ -0,0 +1,323 @@ +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from skimage.io import imread +import onnxruntime as ort +import numpy as np +import pandas as pd +import torch +import onnx +import pdb +import re +import os + +from quant import quant_conv_forward_save_output + +class ModelInference: + """ + This class is used to infer multiple onnx models. + Parameters: + model_path: Path to the model files. + subgraphsiostxt_path: Path to the txt file that describes the structure of the model graph. + Output: + outputs[0]: Inference result from the model. + Description: + Here, subgraphsiostxt_path is a txt file that describes the structure of the model graph and is used to get input/output node names. + The model_path contains paths to multiple onnx files. The load_sessions function will sort the onnx models in the model_path according to the order specified in subgraphsiostxt_path. + It then infers the sorted onnx models, returns the sessions data to self.sessions, and returns the sorted sequence to self.sorted_file_paths. + Finally, it infers the sessions based on the initial data provided by initial_input_data and returns the inference results. + """ + def __init__(self, model_path, subgraphsiostxt_path): + + self.model_path = model_path + self.subgraphsiostxt_path = subgraphsiostxt_path + self.sessions, self.sorted_file_paths = self.load_sessions() + + def load_sessions(self): + with open(self.subgraphsiostxt_path, 'r') as file: + content = file.read() + subgraph_order_map = {} + matches = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content) + + for match in matches: + subgraph_type, subgraph_number, order = match + file_path = os.path.join(self.model_path, f"{subgraph_type}subgraph{subgraph_number}.onnx") + if int(order) in subgraph_order_map: + subgraph_order_map[int(order)].append(file_path) + else: + subgraph_order_map[int(order)] = [file_path] + + sorted_file_paths = [] + for order in sorted(subgraph_order_map.keys()): + sorted_file_paths.extend(subgraph_order_map[order]) + + sessions = [ort.InferenceSession(model) for model in sorted_file_paths] + return sessions, sorted_file_paths + def inference(self, initial_input_data): + input_data = initial_input_data + for i, (session,model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)): + + input_names = [inp.name for inp in session.get_inputs()] + model_input_data = {name: input_data[name] for name in input_names} + outputs = session.run(None, model_input_data) + output_names = [out.name for out in session.get_outputs()] + + if i < len(self.sessions) - 1: + for output, output_name in zip(outputs, output_names): + input_data[output_name] = output + return outputs[0] + + def infer_single_onnx_model(model_file, input_data): + session = ort.InferenceSession(model_file) + outputs = session.run(None, input_data) + return outputs[0] + + + +class PcaInference: + """ + This class uses PCA for compression and inferring multiple ONNX models. + Parameters: + model_path: Path to the onnx model files. + subgraphsiostxt_path: Path to the txt file that describes the structure of the model graph. + endwithconv_path: Path to a txt file recording the onnx ending with convolution. + initial_input_data: Initial input data. + num: Inference times, providing the model name based on the number of times. + output_dir: Root directory for saving inference results. + Output: + outputs: Inference results. + Description: + A result_pt directory is generated in between to save intermediate results; however, not generating this directory does not affect experimental results. + The result folder saves the output of the convolution layer to calculate the compression rate. All results are saved in the output_dir folder. + """ + def __init__(self, model_path, subgraphsiostxt_path, endwithconv_path, output_dir): + self.model_path = model_path + self.subgraphsiostxt_path = subgraphsiostxt_path + self.endwithconv_path = endwithconv_path + self.output_dir = output_dir + ( + self.sessions, + self.conv_output_layer_map, + self.sorted_file_paths, + ) = self.load_sessions() + + def load_sessions(self): + with open(self.subgraphsiostxt_path, 'r') as file: + content = file.read() + subgraph_order_map = {} + matches = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content) + + for match in matches: + subgraph_type, subgraph_number, order = match + file_path = os.path.join(self.model_path, f"{subgraph_type}subgraph{subgraph_number}.onnx") + if int(order) in subgraph_order_map: + subgraph_order_map[int(order)].append(file_path) + else: + subgraph_order_map[int(order)] = [file_path] + + sorted_file_paths = [] + for order in sorted(subgraph_order_map.keys()): + sorted_file_paths.extend(subgraph_order_map[order]) + + sessions = [] + conv_output_layer_map = {} + for model_file in sorted_file_paths: + session = ort.InferenceSession(model_file) + sessions.append(session) + + conv_outputs = {} + if self.onnx_end_conv(model_file): + model = onnx.load(model_file) + for idx, node in enumerate(model.graph.node): + if node.op_type == 'Conv': + for output_name in node.output: + if output_name not in conv_outputs: + conv_outputs[output_name] = idx + 1 + conv_output_layer_map[model_file] = conv_outputs + + return sessions, conv_output_layer_map, sorted_file_paths + def load_onnx_dict(self): + onnx_dict = [] + with open(self.endwithconv_path, 'r') as file: + content = file.read() + numbers = re.findall(r'\b\d+\b', content) + for number in numbers: + onnx_path = os.path.join(self.model_path, f"NPUsubgraph{number}.onnx") + onnx_dict.append(onnx_path) + return onnx_dict + def onnx_end_conv(self, model_file): + for onnx in self.load_onnx_dict(): + if onnx == model_file: + return True + return False + + + def check_and_convert_inputs(self,model_input_data): + for key, value in model_input_data.items(): + if isinstance(value, torch.Tensor): + model_input_data[key] = value.numpy() + elif not isinstance(value, np.ndarray): + raise TypeError(f"Input data for '{key}' is not a NumPy array. Got type: {type(value)}") + return model_input_data + + def decomp(self,compressed_tensor, ru, rbits, num_bits=8): + decompressed_tensor = torch.dequantize(compressed_tensor) + decompressed_tensor = decompressed_tensor.numpy() + if not isinstance(decompressed_tensor, np.ndarray): + raise TypeError("The decompressed tensor is not a NumPy array.") + return decompressed_tensor + + def inference(self, initial_input_data, num): + input_data = initial_input_data + aux_data = {} + record_model_name = None + + for i, (session, model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)): + input_names = [inp.name for inp in session.get_inputs()] + + if self.onnx_end_conv(record_model_name): + for name in input_names: + if name in input_data and name in aux_data: + compressed_tensor = input_data[name] + ru, rbits = aux_data[name] + decompressed_tensor = self.decomp(compressed_tensor, ru, rbits) + input_data[name] = decompressed_tensor + + model_input_data = {name: input_data[name] for name in input_names} + self.check_and_convert_inputs(model_input_data) + outputs = session.run(None, model_input_data) + output_names = [out.name for out in session.get_outputs()] + conv_outputs = self.conv_output_layer_map.get(model_file, {}) + + for output_name, output in zip(output_names, outputs): + if output_name in conv_outputs: + output_tensor = torch.tensor(output) + layer = conv_outputs[output_name] + output_tensor = quant_conv_forward_save_output(output_tensor, layer, count=1, bit=8, i=num, output_dir=self.output_dir) + input_data[output_name] = output_tensor + else: + input_data[output_name] = output + record_model_name = model_file + + return outputs[0] + + + +class ImageMetricsEvaluator: + """ + Used to evaluate image quality, including MSE, PSNR, and SSIM. + + Parameters: + original_dir (str): Directory containing the original images. + generated_dir (str): Directory containing the generated images. + compression_dir (str): Directory containing the compression information text files. + Output: + output_file (str): Path to the output file (Excel). + """ + def __init__(self, original_dir, generated_dir, compression_dir, output_file): + + self.original_dir = original_dir + self.generated_dir = generated_dir + self.compression_dir = compression_dir + self.output_file = output_file + + def calculate_image_metrics(self, original_image_path, generated_image_path): + """Calculate MSE, PSNR, and SSIM between the given original and generated images.""" + original_image = imread(original_image_path) + generated_image = imread(generated_image_path) + + if original_image.shape != generated_image.shape: + raise ValueError('两个图像的尺寸必须相同') + + mse = mean_squared_error(original_image, generated_image) + psnr = peak_signal_noise_ratio(original_image, generated_image) + + min_dim = min(original_image.shape[:2]) + win_size = min(7, min_dim) + if win_size % 2 == 0: + win_size -= 1 + if win_size < 3: + win_size = 3 + + ssim = structural_similarity(original_image, generated_image, multichannel=True, win_size=win_size, channel_axis=-1) + + return mse, psnr, ssim + + def calculate_compression_rate(self, file_path): + """Read from a specified text file and calculate the average compression rate.""" + with open(file_path) as f: + lines = f.readlines() + rate_all = sum(float(line.split(',')[0]) * float(line.split(',')[1]) for line in lines) + all_ = sum(float(line.split(',')[1]) for line in lines) + return rate_all / all_ if all_ != 0 else None + + def find_matching_compression_file(self, image_name): + """Find the corresponding compression info file based on the image filename.""" + base_name, _ = os.path.splitext(image_name) + number = re.search(r'_(\d+)', base_name) + if number: + number = number.group(1) + compression_files = [f for f in os.listdir(self.compression_dir) if f.startswith(f'result_{number}') and f.endswith('.txt')] + if compression_files: + return os.path.join(self.compression_dir, compression_files[0]) + return None + def compare_images_in_directories(self): + """Compare all images in two directories and save the results to an Excel file.""" + def sort_key(filename): + parts = filename.split('_') + try: + return int(parts[1].split('.')[0]) if len(parts) > 1 else 0 + except (ValueError, IndexError): + print(f"Warning: Could not parse number from filename {filename}") + return 0 + + original_images = sorted([f for f in os.listdir(self.original_dir) if f.endswith('.png')], key=sort_key) + generated_images = sorted([f for f in os.listdir(self.generated_dir) if f.endswith('.png')], key=sort_key) + + results = [] + + for orig_img_name, gen_img_name in zip(original_images, generated_images): + orig_img_path = os.path.join(self.original_dir, orig_img_name) + gen_img_path = os.path.join(self.generated_dir, gen_img_name) + + try: + mse, psnr, ssim = self.calculate_image_metrics(orig_img_path, gen_img_path) + compression_file_path = self.find_matching_compression_file(orig_img_name) + compression_rate = self.calculate_compression_rate( compression_file_path) if compression_file_path else None + results.append({ + 'Original Image': orig_img_name, + 'Generated Image': gen_img_name, + 'MSE': mse, + 'PSNR': psnr, + 'SSIM': ssim, + 'Compression Rate': compression_rate + }) + except Exception as e: + print(f"Error processing images {orig_img_name} and {gen_img_name}: {e}") + + df = pd.DataFrame(results) + + output_dir = os.path.dirname(self.output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + try: + df.to_excel(self.output_file, index=False) + print(f'Results have been saved to {self.output_file}') + except PermissionError: + print(f"Permission denied: Unable to write to {self.output_file}. Please check file permissions or close the file if it is open in another program.") + except Exception as e: + print(f"An error occurred while saving the results: {e}") + diff --git a/tools/onnx-subgraph/model_inference_multiple_output.py b/tools/onnx-subgraph/model_inference_multiple_output.py new file mode 100644 index 00000000000..2afa8e8cffb --- /dev/null +++ b/tools/onnx-subgraph/model_inference_multiple_output.py @@ -0,0 +1,327 @@ +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from skimage.io import imread +import onnxruntime as ort +import numpy as np +import pandas as pd +import torch +import onnx +import pdb +import re +import os + +from quant import quant_conv_forward_save_output + +class ModelInference: + """ + This class is used to infer multiple onnx models. + Parameters: + model_path: Path to the model files. + subgraphsiostxt_path: Path to the txt file that describes the structure of the model graph. + Output: + outputs[0]: Inference result from the model. + Description: + Here, subgraphsiostxt_path is a txt file that describes the structure of the model graph and is used to get input/output node names. + The model_path contains paths to multiple onnx files. The load_sessions function will sort the onnx models in the model_path according to the order specified in subgraphsiostxt_path. + It then infers the sorted onnx models, returns the sessions data to self.sessions, and returns the sorted sequence to self.sorted_file_paths. + Finally, it infers the sessions based on the initial data provided by initial_input_data and returns the inference results. + """ + def __init__(self, model_path, subgraphsiostxt_path): + + self.model_path = model_path + self.subgraphsiostxt_path = subgraphsiostxt_path + self.sessions, self.sorted_file_paths = self.load_sessions() + + def load_sessions(self): + with open(self.subgraphsiostxt_path, 'r') as file: + content = file.read() + subgraph_order_map = {} + matches = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content) + + for match in matches: + subgraph_type, subgraph_number, order = match + # lower_subgraph_type = subgraph_type.lower() + file_path = os.path.join(self.model_path, f"{subgraph_type}subgraph{subgraph_number}.onnx") + if int(order) in subgraph_order_map: + subgraph_order_map[int(order)].append(file_path) + else: + subgraph_order_map[int(order)] = [file_path] + + sorted_file_paths = [] + for order in sorted(subgraph_order_map.keys()): + sorted_file_paths.extend(subgraph_order_map[order]) + + sessions = [ort.InferenceSession(model) for model in sorted_file_paths] + return sessions, sorted_file_paths + def inference(self, initial_input_data, output_names_to_collect=None): + input_data = initial_input_data + collected_outputs = {} + + for i, (session, model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)): + input_names = [inp.name for inp in session.get_inputs()] + output_names = [out.name for out in session.get_outputs()] + model_input_data = {name: input_data[name] for name in input_names} + outputs = session.run(None, model_input_data) + current_model_outputs = dict(zip(output_names, outputs)) + if output_names_to_collect is not None: + for output_name in output_names_to_collect: + if output_name in current_model_outputs: + collected_outputs[output_name] = current_model_outputs[output_name] + + if i < len(self.sessions) - 1: + input_data.update(current_model_outputs) + return collected_outputs + + def infer_single_onnx_model(model_file, input_data): + session = ort.InferenceSession(model_file) + outputs = session.run(None, input_data) + output_names = [output.name for output in session.get_outputs()] + output_dict = {name: output for name, output in zip(output_names, outputs)} + return output_dict + + + +class PcaInference: + """ + This class uses PCA for compression and inferring multiple ONNX models. + Parameters: + model_path: Path to the onnx model files. + subgraphsiostxt_path: Path to the txt file that describes the structure of the model graph. + endwithconv_path: Path to a txt file recording the onnx ending with convolution. + initial_input_data: Initial input data. + num: Inference times, providing the model name based on the number of times. + output_dir: Root directory for saving inference results. + Output: + outputs: Inference results. + Description: + A result_pt directory is generated in between to save intermediate results; however, not generating this directory does not affect experimental results. + The result folder saves the output of the convolution layer to calculate the compression rate. All results are saved in the output_dir folder. + """ + def __init__(self, model_path, subgraphsiostxt_path, endwithconv_path, output_dir): + self.model_path = model_path + self.subgraphsiostxt_path = subgraphsiostxt_path + self.endwithconv_path = endwithconv_path + self.output_dir = output_dir + ( + self.sessions, + self.conv_output_layer_map, + self.sorted_file_paths, + ) = self.load_sessions() + + def load_sessions(self): + with open(self.subgraphsiostxt_path, 'r') as file: + content = file.read() + subgraph_order_map = {} + matches = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content) + + for match in matches: + subgraph_type, subgraph_number, order = match + file_path = os.path.join(self.model_path, f"{subgraph_type}subgraph{subgraph_number}.onnx") + if int(order) in subgraph_order_map: + subgraph_order_map[int(order)].append(file_path) + else: + subgraph_order_map[int(order)] = [file_path] + + sorted_file_paths = [] + for order in sorted(subgraph_order_map.keys()): + sorted_file_paths.extend(subgraph_order_map[order]) + + sessions = [] + conv_output_layer_map = {} + for model_file in sorted_file_paths: + session = ort.InferenceSession(model_file) + sessions.append(session) + + conv_outputs = {} + if self.onnx_end_conv(model_file): + model = onnx.load(model_file) + for idx, node in enumerate(model.graph.node): + if node.op_type == 'Conv': + for output_name in node.output: + if output_name not in conv_outputs: + conv_outputs[output_name] = idx + 1 + conv_output_layer_map[model_file] = conv_outputs + + return sessions, conv_output_layer_map, sorted_file_paths + def load_onnx_dict(self): + onnx_dict = [] + with open(self.endwithconv_path, 'r') as file: + content = file.read() + numbers = re.findall(r'\b\d+\b', content) + for number in numbers: + onnx_path = os.path.join(self.model_path, f"NPUsubgraph{number}.onnx") + onnx_dict.append(onnx_path) + return onnx_dict + def onnx_end_conv(self, model_file): + for onnx in self.load_onnx_dict(): + if onnx == model_file: + return True + return False + + + def check_and_convert_inputs(self,model_input_data): + for key, value in model_input_data.items(): + if isinstance(value, torch.Tensor): + model_input_data[key] = value.numpy() + elif not isinstance(value, np.ndarray): + raise TypeError(f"Input data for '{key}' is not a NumPy array. Got type: {type(value)}") + return model_input_data + + def decomp(self,compressed_tensor, ru, rbits, num_bits=8): + decompressed_tensor = torch.dequantize(compressed_tensor) + decompressed_tensor = decompressed_tensor.numpy() + if not isinstance(decompressed_tensor, np.ndarray): + raise TypeError("The decompressed tensor is not a NumPy array.") + return decompressed_tensor + + def inference(self, initial_input_data, num): + input_data = initial_input_data + aux_data = {} + record_model_name = None + + for i, (session, model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)): + input_names = [inp.name for inp in session.get_inputs()] + + if self.onnx_end_conv(record_model_name): + for name in input_names: + if name in input_data and name in aux_data: + compressed_tensor = input_data[name] + ru, rbits = aux_data[name] + decompressed_tensor = self.decomp(compressed_tensor, ru, rbits) + input_data[name] = decompressed_tensor + + model_input_data = {name: input_data[name] for name in input_names} + self.check_and_convert_inputs(model_input_data) + outputs = session.run(None, model_input_data) + output_names = [out.name for out in session.get_outputs()] + conv_outputs = self.conv_output_layer_map.get(model_file, {}) + + for output_name, output in zip(output_names, outputs): + if output_name in conv_outputs: + output_tensor = torch.tensor(output) + layer = conv_outputs[output_name] + output_tensor = quant_conv_forward_save_output(output_tensor, layer, count=1, bit=8, i=num, output_dir=self.output_dir) + input_data[output_name] = output_tensor + else: + input_data[output_name] = output + record_model_name = model_file + + return outputs[0] + + + +class ImageMetricsEvaluator: + """ + Used to evaluate image quality, including MSE, PSNR, and SSIM. + + Parameters: + original_dir (str): Directory containing the original images. + generated_dir (str): Directory containing the generated images. + compression_dir (str): Directory containing the compression information text files. + Output: + output_file (str): Path to the output file (Excel). + """ + def __init__(self, original_dir, generated_dir, compression_dir, output_file): + + self.original_dir = original_dir + self.generated_dir = generated_dir + self.compression_dir = compression_dir + self.output_file = output_file + + def calculate_image_metrics(self, original_image_path, generated_image_path): + original_image = imread(original_image_path) + generated_image = imread(generated_image_path) + + if original_image.shape != generated_image.shape: + raise ValueError('两个图像的尺寸必须相同') + + mse = mean_squared_error(original_image, generated_image) + psnr = peak_signal_noise_ratio(original_image, generated_image) + + min_dim = min(original_image.shape[:2]) + win_size = min(7, min_dim) + if win_size % 2 == 0: + win_size -= 1 + if win_size < 3: + win_size = 3 + + ssim = structural_similarity(original_image, generated_image, multichannel=True, win_size=win_size, channel_axis=-1) + + return mse, psnr, ssim + + def calculate_compression_rate(self, file_path): + with open(file_path) as f: + lines = f.readlines() + rate_all = sum(float(line.split(',')[0]) * float(line.split(',')[1]) for line in lines) + all_ = sum(float(line.split(',')[1]) for line in lines) + return rate_all / all_ if all_ != 0 else None + + def find_matching_compression_file(self, image_name): + base_name, _ = os.path.splitext(image_name) + number = re.search(r'_(\d+)', base_name) + if number: + number = number.group(1) + compression_files = [f for f in os.listdir(self.compression_dir) if f.startswith(f'result_{number}') and f.endswith('.txt')] + if compression_files: + return os.path.join(self.compression_dir, compression_files[0]) + return None + def compare_images_in_directories(self): + def sort_key(filename): + parts = filename.split('_') + try: + return int(parts[1].split('.')[0]) if len(parts) > 1 else 0 + except (ValueError, IndexError): + print(f"Warning: Could not parse number from filename {filename}") + return 0 + + original_images = sorted([f for f in os.listdir(self.original_dir) if f.endswith('.png')], key=sort_key) + generated_images = sorted([f for f in os.listdir(self.generated_dir) if f.endswith('.png')], key=sort_key) + + results = [] + + for orig_img_name, gen_img_name in zip(original_images, generated_images): + orig_img_path = os.path.join(self.original_dir, orig_img_name) + gen_img_path = os.path.join(self.generated_dir, gen_img_name) + + try: + mse, psnr, ssim = self.calculate_image_metrics(orig_img_path, gen_img_path) + compression_file_path = self.find_matching_compression_file(orig_img_name) + compression_rate = self.calculate_compression_rate( compression_file_path) if compression_file_path else None + results.append({ + 'Original Image': orig_img_name, + 'Generated Image': gen_img_name, + 'MSE': mse, + 'PSNR': psnr, + 'SSIM': ssim, + 'Compression Rate': compression_rate + }) + except Exception as e: + print(f"Error processing images {orig_img_name} and {gen_img_name}: {e}") + + df = pd.DataFrame(results) + + output_dir = os.path.dirname(self.output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + try: + df.to_excel(self.output_file, index=False) + print(f'Results have been saved to {self.output_file}') + except PermissionError: + print(f"Permission denied: Unable to write to {self.output_file}. Please check file permissions or close the file if it is open in another program.") + except Exception as e: + print(f"An error occurred while saving the results: {e}") + diff --git a/tools/onnx-subgraph/onnx.proto b/tools/onnx-subgraph/onnx.proto new file mode 100644 index 00000000000..6a3abfdd109 --- /dev/null +++ b/tools/onnx-subgraph/onnx.proto @@ -0,0 +1,871 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// SPDX-License-Identifier: Apache-2.0 + + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on May 8, 2020 + // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Implicitly add inference graph into each TrainingInfoProto's algorithm. + IR_VERSION_2020_5_8 = 0x0000000000000007; + + // IR VERSION 8 published on July 30, 2021 + // Introduce TypeProto.SparseTensor + // Introduce TypeProto.Optional + // Added a list of FunctionProtos local to the model + // Deprecated since_version and operator status from FunctionProto + IR_VERSION_2021_7_30 = 0x0000000000000008; + + // IR VERSION 9 published on May 5, 2023 + // Added AttributeProto to FunctionProto so that default attribute values can be set. + // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. + IR_VERSION_2023_5_5 = 0x0000000000000009; + + // IR VERSION 10 published on TBD + // Added UINT4, INT4. + IR_VERSION = 0x000000000000000A; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + reserved 12, 16 to 19; + reserved "v"; + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + TYPE_PROTO = 13; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + TYPE_PROTOS = 14; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field heuristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accommodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + optional SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + optional TypeProto tp = 14; // type proto + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors + repeated TypeProto type_protos = 15;// list of type protos +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 4; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in this version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + // Overload identifier, used only to map this to a model-local function. + optional string overload = 8; + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 9; +} + +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been performed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update steps (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each step. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Thus, no initializer would be changed by default. + optional GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this field contains loss node, gradient node, + // optimizer node, increment of iteration count. + // + // An execution of the training algorithm step is performed by executing the + // graph obtained by combining the inference graph (namely "ModelProto.graph") + // and the "algorithm" graph. That is, the actual + // input/initializer/output/node/value_info/sparse_initializer list of + // the training graph is the concatenation of + // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" + // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" + // in that order. This combined graph must satisfy the normal ONNX conditions. + // Now, let's provide a visualization of graph combination for clarity. + // Let the inference graph (i.e., "ModelProto.graph") be + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d + // and the "algorithm" graph be + // tensor_d -> Add -> tensor_e + // The combination process results + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e + // + // Notice that an input of a node in the "algorithm" graph may reference the + // output of a node in the inference graph (but not the other way round). Also, inference + // node cannot reference inputs of "algorithm". With these restrictions, inference graph + // can always be run independently without training information. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Evaluating the default training step never + // update any initializers. + optional GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two + // variables may not have the same name. This ensures that one + // variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". + // 4. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto's. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; + + // A list of function protos local to the model. + // + // The (domain, name, overload) tuple must be unique across the function protos in this list. + // In case of any conflicts the behavior (whether the model local functions are given higher priority, + // or standard operator sets are given higher priotity or this is treated as error) is defined by + // the runtimes. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto and other model local FunctionProtos. + // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto + // or by 2 FunctionProtos then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same for every node in the function body. + // + // One FunctionProto can reference other FunctionProto in the model, however, recursive reference + // is not allowed. + repeated FunctionProto functions = 25; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value = 2; +}; + +message TensorAnnotation { + optional string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. + // The name MUST be unique across both initializer and sparse_initializer, + // but the name MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 16; + + reserved 3, 4, 6 to 9; + reserved "ir_version", "producer_version", "producer_tag", "domain"; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Non-IEEE floating-point format based on papers + // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, + // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf. + // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear. + // The computation usually happens inside a block quantize / dequantize + // fused by the runtime. + FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf + FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero + FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients + FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero + + // 4-bit data-types + UINT4 = 21; // Unsigned integer in range [0, 15] + INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, uint4, int4, bool, float8 and float16 values + // float16 and float8 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // uint4 and int4 values must be packed to 4bitx2 prior to writing to the buffer, the first element is stored in + // the 4 LSB and the second element is stored in the 4 MSB. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, INT4, UINT16, UINT8, UINT4, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // uint4 and int4 values must be packed to 4bitx2, the first element is stored in the 4 LSB and the second element is stored in the 4 MSB. + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 16; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + // values must have a non-empty name present which serves as a name for SparseTensorProto + // when used in sparse_initializer list. + optional TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + optional TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + optional TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + optional int32 key_type = 1; + // This field MUST be present for this version of the IR. + optional TypeProto value_type = 2; + }; + + // wrapper for Tensor, Sequence, or Map + message Optional { + // The type and optional shape of the element wrapped. + // This field MUST be present for this version of the IR. + // Possible values correspond to OptionalProto.DataType enum + optional TypeProto elem_type = 1; + }; + + + message SparseTensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + // The type of an optional. + Optional optional_type = 9; + + + // Type of the sparse tensor + SparseTensor sparse_tensor_type = 8; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +} + +// Operator/function status. +enum OperatorStatus { + EXPERIMENTAL = 0; + STABLE = 1; +} + +message FunctionProto { + // The name of the function, similar to op_type in NodeProto. + // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. + optional string name = 1; + + // Deprecated since IR Version 8 + // optional int64 since_version = 2; + reserved 2; + reserved "since_version"; + + // Deprecated since IR Version 8 + // optional OperatorStatus status = 3; + reserved 3; + reserved "status"; + + // The inputs and outputs of the function. + repeated string input = 4; + repeated string output = 5; + + // The attribute parameters of the function. + // It is for function parameters without default values. + repeated string attribute = 6; + + // The attribute protos of the function. + // It is for function attributes with default values. + // A function attribute shall be represented either as + // a string attribute or an AttributeProto, not both. + repeated AttributeProto attribute_proto = 11; + + // The nodes in the function. + repeated NodeProto node = 7; + // A human-readable documentation for this function. Markdown is allowed. + optional string doc_string = 8; + + // The OperatorSets this function body (graph) relies on. + // + // All nodes in the function body (graph) will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. This means at most one version can be relied + // for one domain. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto + // and ModelProto then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same. + + repeated OperatorSetIdProto opset_import = 9; + + // The domain which this function belongs to. + // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. + optional string domain = 10; + + // The overload identifier of the function. + // This is part of the unique-id (domain, name, overload) of FunctionProtos in a model. + optional string overload = 13; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +} + +// For using protobuf-lite +option optimize_for = LITE_RUNTIME; diff --git a/tools/onnx-subgraph/onnx_subgraph_ut.py b/tools/onnx-subgraph/onnx_subgraph_ut.py new file mode 100644 index 00000000000..13428228989 --- /dev/null +++ b/tools/onnx-subgraph/onnx_subgraph_ut.py @@ -0,0 +1,62 @@ +import unittest +import os +import sys +import extract_onnx_lib +import shutil + +def onnx_parser_test(args): + #exe = './onnx-subgraph ' + '--onnx=test.onnx' + exe = './onnx-subgraph ' + args + rec = os.system(exe) + +class ONNX_Parser_Test(unittest.TestCase): + def test_parse_result_exception(self): + ret = os.path.exists('./subgraphs_ios.txt') + if ret: + os.remove('./subgraphs_ios.txt') + onnx_parser_test('--onnx=no_file.onnx') + ret = os.path.exists('./subgraphs_ios.txt') + self.assertEqual(ret, False) + + def test_parse_result_normal(self): + ret = os.path.exists('./subgraphs_ios.txt') + if ret: + os.remove('./subgraphs_ios.txt') + + onnx_parser_test('--onnx=test.onnx') + ret = os.path.exists('./subgraphs_ios.txt') + self.assertEqual(ret, True) + + def test_subgraph_normal(self): + ret = os.path.exists('./subgraphs') + if ret: + shutil.rmtree(path='./subgraphs') + + extract_onnx_lib.split_onnx_ios('./subgraphs_ios.txt','./test.onnx') + ret = os.path.exists('./subgraphs') + self.assertEqual(ret, True) + + ret = os.path.exists('./subgraphs/CPU') + self.assertEqual(ret, True) + + ret = os.path.exists('./subgraphs/NPU') + self.assertEqual(ret, True) + + ret = os.path.exists('./subgraphs/CPU/CPUsubgraph15.onnx') + self.assertEqual(ret, True) + + ret = os.path.exists('./subgraphs/NPU/NPUsubgraph15.onnx') + self.assertEqual(ret, True) + + def test_subgraph_exception(self): + ret = os.path.exists('./subgraphs') + if ret: + shutil.rmtree(path='./subgraphs') + + extract_onnx_lib.split_onnx_ios('./subgraphs_ios.txt','./fake.onnx') + ret = os.path.exists('./subgraphs') + self.assertEqual(ret, False) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/onnx-subgraph/quant.py b/tools/onnx-subgraph/quant.py new file mode 100644 index 00000000000..ee9e57665f0 --- /dev/null +++ b/tools/onnx-subgraph/quant.py @@ -0,0 +1,425 @@ +# Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import time +from types import MethodType + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from tqdm import tqdm +from sklearn.cluster import KMeans +from torch.nn import functional as F +import numpy as np +import pdb +def quant_transmartix(x, bits=8): + # Quantizes the input tensor and returns the quantized tensor and its integer representation. + if(x.max() == x.min()): + return x, 0 + n = 2 ** (bits - 1) - 1 + act_scale = (x.max() - x.min()) / 2 / n + zero_point = (x.min() + x.max()) / 2 + aint = ((x - zero_point) / act_scale).round().clamp(-n - 1, n) + xq = aint * act_scale + zero_point + return xq, aint + +def quant_transmartix1(x, bits=8): + # Computes the projection matrix using singular value decomposition and quantizes it. + cov = torch.matmul(im, im.t()) / im.shape[1] + if(x.max() == x.min()): + return x, 0 + n = 2 ** (bits - 1) - 1 + act_scale = (x.max() - x.min()) / 2 / n + zero_point = (x.min() + x.max()) / 2 + aint = ((x - zero_point) / act_scale).round().clamp(-n - 1, n) + return aint, act_scale, zero_point + +def get_projection_matrix(im, eigenVar,num_bits=8): + # covariance matrix + cov = torch.matmul(im, im.t()) / im.shape[1] + # svd + u, s, _ = torch.svd(cov) + u,_ = quant_transmartix(u,16) + return u, s + +def comp(x, rate, output_dir, count ,transu, inb, num_bits, layer): + # Compresses the input tensor using a transformation matrix and quantizes the result. + if(len(x.shape) == 2): + B,C = x.shape + x_reshape = x + elif(len(x.shape) == 3): + B,C,H = x.shape + x_reshape = x.permute(1,0,2).reshape(C,-1) + elif(len(x.shape) == 4): + B, C, H, W = x.shape + x_reshape = x.permute(1, 0, 2, 3).reshape(C, -1) + else: + raise NotImplementedError + if(count==1): + u,s = get_projection_matrix(x_reshape, rate, num_bits) + x_trans = torch.matmul(u.t(), x_reshape) + x_trans, x_trans_int = quant_transmartix(x_trans,num_bits) + channel_max = x_trans_int.max(-1)[0].reshape(1,-1) + channel_min = x_trans_int.min(-1)[0].reshape(1,-1) + channel_dif = channel_max-channel_min + channel_dif[torch.where(channel_dif==0)]=1 + bits = torch.ceil(torch.log2(channel_dif)) + max_min = torch.cat([channel_max,channel_min],dim=0) + x_return = torch.matmul(u, x_trans) + x_return, x_return_int = quant_transmartix(x_return,num_bits) + ru = u + rbits=max_min + elif(count<=100): + x_trans = torch.matmul(transu.t(), x_reshape) + x_trans,x_trans_int = quant_transmartix(x_trans,num_bits) + channel_max = x_trans_int.max(-1)[0].reshape(1,-1) + channel_min = x_trans_int.min(-1)[0].reshape(1,-1) + max_min = torch.cat([channel_max,channel_min],dim=0) + x_return = torch.matmul(transu, x_trans) + x_return,x_return_int = quant_transmartix(x_return,num_bits) + ru = None + rbits=max_min + else: + x_trans = torch.matmul(transu.t(), x_reshape) + x_trans_int,act_scale,zero_point = quant_transmartix1(x_trans,num_bits) + inb_expend = inb[:,:,None].repeat(1,1,H*W) + mask_clip_max = torch.where(x_trans_int>inb_expend[0]) + mask_clip_min = torch.where(x_trans_int &Subgraphs, std::string device, + std::vector> &subgraphs_inputs, + std::vector> &subgraphs_outputs) +{ + std::cout << "Generate Cut Instruction for Target_NPU" << std::endl; + // open file + std::string file_name = device + "CutInstruction.txt"; + std::ofstream outFile(file_name); + if (!outFile.is_open()) { + std::cerr << "Error opening file." << std::endl; + exit(0); + } + for (size_t i = 0; i < Subgraphs.size(); i++) { + // default parameters + std::string modelFile = onnxFile; + std::string dataScaleDiv = "255"; + std::string postprocess = "save_and_top5"; + + std::unordered_set graphInputs = subgraphs_inputs[i]; + std::unordered_set graphOutputs = subgraphs_outputs[i]; + + std::string inputName = "\""; + for (const auto& input : graphInputs) { + inputName = inputName + input.name + ";"; + } + // delete last semicolon + if (!inputName.empty() && inputName.back() == ';') { + inputName.pop_back(); + } + inputName = inputName + "\""; + std::string outputName = "\""; + for (const auto& output : graphOutputs) { + outputName = outputName + output.name + ";"; + } + // delete last semicolon + if (!outputName.empty() && outputName.back() == ';') { + outputName.pop_back(); + } + outputName = outputName + "\""; + + std::string inputShape = "\""; + for (const auto& input : graphInputs) { + for (const auto& dim : input.shape) { + inputShape = inputShape + std::to_string(dim) + " "; + } + // delete last space + if (!inputShape.empty() && inputShape.back() == ' ') { + inputShape.pop_back(); + } + inputShape = inputShape + ";"; + } + // delete last semicolon + if (!inputShape.empty() && inputShape.back() == ';') { + inputShape.pop_back(); + } + inputShape = inputShape + "\""; + + std::string calibrateDataset = device + "_Subgraphs_" + std::to_string(i) + ".npz"; + std::string quantizationScheme = "int8_asym"; + + } + + outFile.close(); + +} diff --git a/tools/onnx-subgraph/src/lib/graph.cpp b/tools/onnx-subgraph/src/lib/graph.cpp new file mode 100644 index 00000000000..8c5c457b40d --- /dev/null +++ b/tools/onnx-subgraph/src/lib/graph.cpp @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph.h" +#include "partition.h" + +std::unordered_set getInitializer(const onnx::GraphProto& graph) { + std::unordered_set initializerNames; + for (const auto& initializer : graph.initializer()) { + NodeTensor nt; + nt.name = initializer.name(); + std::vector shape; + for (const auto& dim : initializer.dims()) { + shape.push_back(dim); + } + nt.shape = shape; + initializerNames.insert(nt); + } + return initializerNames; +} + +std::unordered_set getIOvalue(const onnx::GraphProto& graph) { + std::unordered_set IOvalue; + for (const auto& value_info : graph.value_info()) { + NodeTensor nt; + nt.name = value_info.name(); + + std::vector shape; + for (const auto& dim : value_info.type().tensor_type().shape().dim()) { + shape.push_back(dim.dim_value()); + } + nt.shape = shape; + IOvalue.insert(nt); + } + for (auto value_info : graph.input()) { + NodeTensor nt; + nt.name = value_info.name(); + + std::vector shape; + for (const auto& dim : value_info.type().tensor_type().shape().dim()) { + shape.push_back(dim.dim_value()); + } + nt.shape = shape; + IOvalue.insert(nt); + } + for (auto value_info : graph.output()) { + NodeTensor nt; + nt.name = value_info.name(); + + std::vector shape; + for (const auto& dim : value_info.type().tensor_type().shape().dim()) { + shape.push_back(dim.dim_value()); + } + nt.shape = shape; + IOvalue.insert(nt); + } + return IOvalue; +} +/** +* @brief Finds a NodeTensor with the specified name in the given set of NodeTensors. +* +* @param [in] name The name of the NodeTensor to find. +* @param [in] tensors The set of NodeTensors to search within. +* @pre The tensors set should be valid and contain NodeTensor objects. +* @post None +* @exception None +* @return An iterator to the found NodeTensor if it exists, otherwise an iterator to the end of the set. +*/ +std::unordered_set::const_iterator isInputFromInitializer(const std::string& name, const std::unordered_set& tensors) { + return std::find_if(tensors.begin(), tensors.end(), [&](const NodeTensor& tensor) { return tensor.name == name; }); +} + +void determineGraphInput(const onnx::GraphProto& g, const std::unordered_set& initializerNames, + std::unordered_set &graphInputs) { + std::unordered_set allnodeOutputs; + + // Iterate over each node in the graph to collect all outputs + for (const auto& node : g.node()) { + // Get the output list of the current node + const auto& outputs = node.output(); + + // Insert each output into the set of all node outputs + for (const auto& output : outputs) { + allnodeOutputs.insert(output); + } + } + + // Iterate over each node in the graph to identify inputs not produced by any node + for (const auto& node : g.node()) { + // Get the input list of the current node + const auto& inputs = node.input(); + + // Check each input to determine if it is an external input to the graph + for (const auto& input : inputs) { + // If the input is not found in the set of all node outputs, it is a graph input + if (std::find(allnodeOutputs.begin(), allnodeOutputs.end(), input) == allnodeOutputs.end()) { + auto iter = isInputFromInitializer(input, initializerNames); + NodeTensor nt; + nt.name = input; + if (iter != initializerNames.end()) { + graphInputs.insert(*iter); + } + } + } + } +} + +void determineGraphOutput(const onnx::GraphProto& originalGraph, const onnx::GraphProto& g, std::vector> &allgraphInputs_1, + std::vector> &allgraphInputs_2, std::unordered_set &graphOutputs) { + auto allgraphInputs = allgraphInputs_1; + allgraphInputs.insert(allgraphInputs.end(), allgraphInputs_2.begin(), allgraphInputs_2.end()); + for (const auto& node : g.node()) { + const auto& outputs = node.output(); + for (const auto& output : outputs) { + int flag = 0; + for (auto value_info : originalGraph.output()) { + if (value_info.name() == output) { + NodeTensor nt; + nt.name = value_info.name(); + std::cout << nt.name << std::endl; + std::vector shape; + for (const auto& dim : value_info.type().tensor_type().shape().dim()) { + shape.push_back(dim.dim_value()); + } + nt.shape = shape; + graphOutputs.insert(nt); + flag = 1; + break; + } + } + if (flag) { + continue; + } + for (size_t i = 0; i < allgraphInputs.size(); i++) { + for (auto& input : allgraphInputs[i]) { + if (input.name == output) { + graphOutputs.insert(input); + flag = 1; + break; + } + } + if (flag) { + break; + } + } + } + } +} +std::string findInputNode(const onnx::GraphProto &g, const std::string& outputTensorName) { + std::string node_name = ""; + for (const auto& node : g.node()) { + for (const auto& output : node.output()) { + if (output == outputTensorName) { + node_name = node.name(); + } + } + } + return node_name; +} + +std::unordered_set collectNodeNames(const onnx::GraphProto& graph) { + std::unordered_set nodeNames; + for (const auto& node : graph.node()) { + nodeNames.insert(node.name()); + } + return nodeNames; +} + +void mergeGraphs(onnx::GraphProto& targetGraph, onnx::GraphProto& sourceGraph) { + std::cout<<"size before merged: "< buffer(size); + input.read(buffer.data(), size); + model.ParseFromArray(buffer.data(), size); // parse protobuf + return model.graph(); +} diff --git a/tools/onnx-subgraph/src/lib/jsoncpp.cpp b/tools/onnx-subgraph/src/lib/jsoncpp.cpp new file mode 100644 index 00000000000..26afc825efc --- /dev/null +++ b/tools/onnx-subgraph/src/lib/jsoncpp.cpp @@ -0,0 +1,4951 @@ +/// Json-cpp amalgated source (http://jsoncpp.sourceforge.net/). +/// It is intended to be used with #include "json/json.h" + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + +/* +The JsonCpp library's source code, including accompanying documentation, +tests and demonstration applications, are licensed under the following +conditions... + +The author (Baptiste Lepilleur) explicitly disclaims copyright in all +jurisdictions which recognize such a disclaimer. In such jurisdictions, +this software is released into the Public Domain. + +In jurisdictions which do not recognize Public Domain property (e.g. Germany as of +2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is +released under the terms of the MIT License (see below). + +In jurisdictions which recognize Public Domain property, the user of this +software may choose to accept it either as 1) Public Domain, 2) under the +conditions of the MIT License (see below), or 3) under the terms of dual +Public Domain/MIT License conditions described here, as they choose. + +The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: + + http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +======================================================================== +Copyright (c) 2007-2010 Baptiste Lepilleur + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +======================================================================== +(END LICENSE TEXT) + +The MIT license is compatible with both the GPL and commercial +software, affording one all of the rights of Public Domain with the +minor nuisance of being required to keep the above copyright notice +and license text in the source code. Note also that by accepting the +Public Domain "license" you can re-license your copy using whatever +license you like. + +*/ + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + + + + + + +#include "json.h" + +#ifndef JSON_IS_AMALGAMATION +#error "Compile with -I PATH_TO_JSON_DIRECTORY" +#endif + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_tool.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef LIB_JSONCPP_JSON_TOOL_H_INCLUDED +#define LIB_JSONCPP_JSON_TOOL_H_INCLUDED + +/* This header provides common string manipulation support, such as UTF-8, + * portable conversion from/to string... + * + * It is an internal header that must not be exposed. + */ + +namespace Json { + +/// Converts a unicode code-point to UTF-8. +static inline std::string codePointToUTF8(unsigned int cp) { + std::string result; + + // based on description from http://en.wikipedia.org/wiki/UTF-8 + + if (cp <= 0x7f) { + result.resize(1); + result[0] = static_cast(cp); + } else if (cp <= 0x7FF) { + result.resize(2); + result[1] = static_cast(0x80 | (0x3f & cp)); + result[0] = static_cast(0xC0 | (0x1f & (cp >> 6))); + } else if (cp <= 0xFFFF) { + result.resize(3); + result[2] = static_cast(0x80 | (0x3f & cp)); + result[1] = static_cast(0x80 | (0x3f & (cp >> 6))); + result[0] = static_cast(0xE0 | (0xf & (cp >> 12))); + } else if (cp <= 0x10FFFF) { + result.resize(4); + result[3] = static_cast(0x80 | (0x3f & cp)); + result[2] = static_cast(0x80 | (0x3f & (cp >> 6))); + result[1] = static_cast(0x80 | (0x3f & (cp >> 12))); + result[0] = static_cast(0xF0 | (0x7 & (cp >> 18))); + } + + return result; +} + +/// Returns true if ch is a control character (in range [1,31]). +static inline bool isControlCharacter(char ch) { return ch > 0 && ch <= 0x1F; } + +enum { + /// Constant that specify the size of the buffer that must be passed to + /// uintToString. + uintToStringBufferSize = 3 * sizeof(LargestUInt) + 1 +}; + +// Defines a char buffer for use with uintToString(). +typedef char UIntToStringBuffer[uintToStringBufferSize]; + +/** Converts an unsigned integer to string. + * @param value Unsigned interger to convert to string + * @param current Input/Output string buffer. + * Must have at least uintToStringBufferSize chars free. + */ +static inline void uintToString(LargestUInt value, char*& current) { + *--current = 0; + do { + *--current = static_cast(value % 10U + static_cast('0')); + value /= 10; + } while (value != 0); +} + +/** Change ',' to '.' everywhere in buffer. + * + * We had a sophisticated way, but it did not work in WinCE. + * @see https://github.com/open-source-parsers/jsoncpp/pull/9 + */ +static inline void fixNumericLocale(char* begin, char* end) { + while (begin < end) { + if (*begin == ',') { + *begin = '.'; + } + ++begin; + } +} + +} // namespace Json { + +#endif // LIB_JSONCPP_JSON_TOOL_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_tool.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_reader.cpp +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2011 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include +#include "json_tool.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__BORLANDC__) +#include +#endif +#if defined(_MSC_VER) +#if !defined(WINCE) && defined(__STDC_SECURE_LIB__) && _MSC_VER >= 1500 // VC++ 9.0 and above +#define snprintf sprintf_s +#elif _MSC_VER >= 1900 // VC++ 14.0 and above +#define snprintf std::snprintf +#else +#define snprintf _snprintf +#endif +#elif defined(__ANDROID__) +#define snprintf snprintf +#elif __cplusplus >= 201103L +#define snprintf std::snprintf +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 // VC++ 8.0 +// Disable warning about strdup being deprecated. +#pragma warning(disable : 4996) +#endif + +static int const stackLimit_g = 1000; +static int stackDepth_g = 0; // see readValue() + +namespace Json { + +#if JSON_HAS_UNIQUE_PTR +typedef std::unique_ptr const CharReaderPtr; +#else +typedef std::auto_ptr CharReaderPtr; +#endif + +// Implementation of class Features +// //////////////////////////////// + +Features::Features() + : allowComments_(true), strictRoot_(false) +{} +Features Features::all() { return Features(); } + +Features Features::strictMode() { + Features features; + features.allowComments_ = false; + features.strictRoot_ = true; + return features; +} + +// Implementation of class Reader +// //////////////////////////////// + +static bool containsNewLine(Reader::Location begin, Reader::Location end) { + for (; begin < end; ++begin) + if (*begin == '\n' || *begin == '\r') + return true; + return false; +} + +// Class Reader +// ////////////////////////////////////////////////////////////////// + +Reader::Reader() + : errors_(), document_(), begin_(), end_(), current_(), lastValueEnd_(), + lastValue_(), commentsBefore_(), features_(Features::all()), + collectComments_() {} + +Reader::Reader(const Features& features) + : errors_(), document_(), begin_(), end_(), current_(), lastValueEnd_(), + lastValue_(), commentsBefore_(), features_(features), collectComments_() { +} + +bool +Reader::parse(const std::string& document, Value& root, bool collectComments) { + document_ = document; + const char* begin = document_.c_str(); + const char* end = begin + document_.length(); + return parse(begin, end, root, collectComments); +} + +bool Reader::parse(std::istream& sin, Value& root, bool collectComments) { + // std::istream_iterator begin(sin); + // std::istream_iterator end; + // Those would allow streamed input from a file, if parse() were a + // template function. + + // Since std::string is reference-counted, this at least does not + // create an extra copy. + std::string doc; + std::getline(sin, doc, (char)EOF); + return parse(doc, root, collectComments); +} + +bool Reader::parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments) { + if (!features_.allowComments_) { + collectComments = false; + } + + begin_ = beginDoc; + end_ = endDoc; + collectComments_ = collectComments; + current_ = begin_; + lastValueEnd_ = 0; + lastValue_ = 0; + commentsBefore_ = ""; + errors_.clear(); + while (!nodes_.empty()) + nodes_.pop(); + nodes_.push(&root); + + stackDepth_g = 0; // Yes, this is bad coding, but options are limited. + bool successful = readValue(); + Token token; + skipCommentTokens(token); + if (collectComments_ && !commentsBefore_.empty()) + root.setComment(commentsBefore_, commentAfter); + if (features_.strictRoot_) { + if (!root.isArray() && !root.isObject()) { + // Set error location to start of doc, ideally should be first token found + // in doc + token.type_ = tokenError; + token.start_ = beginDoc; + token.end_ = endDoc; + addError( + "A valid JSON document must be either an array or an object value.", + token); + return false; + } + } + return successful; +} + +bool Reader::readValue() { + // This is a non-reentrant way to support a stackLimit. Terrible! + // But this deprecated class has a security problem: Bad input can + // cause a seg-fault. This seems like a fair, binary-compatible way + // to prevent the problem. + if (stackDepth_g >= stackLimit_g) throwRuntimeError("Exceeded stackLimit in readValue()."); + ++stackDepth_g; + + Token token; + skipCommentTokens(token); + bool successful = true; + + if (collectComments_ && !commentsBefore_.empty()) { + currentValue().setComment(commentsBefore_, commentBefore); + commentsBefore_ = ""; + } + + switch (token.type_) { + case tokenObjectBegin: + successful = readObject(token); + break; + case tokenArrayBegin: + successful = readArray(token); + break; + case tokenNumber: + successful = decodeNumber(token); + break; + case tokenString: + successful = decodeString(token); + break; + case tokenTrue: + { + Value v(true); + currentValue().swapPayload(v); + } + break; + case tokenFalse: + { + Value v(false); + currentValue().swapPayload(v); + } + break; + case tokenNull: + { + Value v; + currentValue().swapPayload(v); + } + break; + // Else, fall through... + default: + return addError("Syntax error: value, object or array expected.", token); + } + + if (collectComments_) { + lastValueEnd_ = current_; + lastValue_ = ¤tValue(); + } + + --stackDepth_g; + return successful; +} + +void Reader::skipCommentTokens(Token& token) { + if (features_.allowComments_) { + do { + readToken(token); + } while (token.type_ == tokenComment); + } else { + readToken(token); + } +} + +bool Reader::readToken(Token& token) { + skipSpaces(); + token.start_ = current_; + Char c = getNextChar(); + bool ok = true; + switch (c) { + case '{': + token.type_ = tokenObjectBegin; + break; + case '}': + token.type_ = tokenObjectEnd; + break; + case '[': + token.type_ = tokenArrayBegin; + break; + case ']': + token.type_ = tokenArrayEnd; + break; + case '"': + token.type_ = tokenString; + ok = readString(); + break; + case '/': + token.type_ = tokenComment; + ok = readComment(); + break; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + token.type_ = tokenNumber; + readNumber(); + break; + case 't': + token.type_ = tokenTrue; + ok = match("rue", 3); + break; + case 'f': + token.type_ = tokenFalse; + ok = match("alse", 4); + break; + case 'n': + token.type_ = tokenNull; + ok = match("ull", 3); + break; + case ',': + token.type_ = tokenArraySeparator; + break; + case ':': + token.type_ = tokenMemberSeparator; + break; + case 0: + token.type_ = tokenEndOfStream; + break; + default: + ok = false; + break; + } + if (!ok) + token.type_ = tokenError; + token.end_ = current_; + return true; +} + +void Reader::skipSpaces() { + while (current_ != end_) { + Char c = *current_; + if (c == ' ' || c == '\t' || c == '\r' || c == '\n') + ++current_; + else + break; + } +} + +bool Reader::match(Location pattern, int patternLength) { + if (end_ - current_ < patternLength) + return false; + int index = patternLength; + while (index--) + if (current_[index] != pattern[index]) + return false; + current_ += patternLength; + return true; +} + +bool Reader::readComment() { + Location commentBegin = current_ - 1; + Char c = getNextChar(); + bool successful = false; + if (c == '*') + successful = readCStyleComment(); + else if (c == '/') + successful = readCppStyleComment(); + if (!successful) + return false; + + if (collectComments_) { + CommentPlacement placement = commentBefore; + if (lastValueEnd_ && !containsNewLine(lastValueEnd_, commentBegin)) { + if (c != '*' || !containsNewLine(commentBegin, current_)) + placement = commentAfterOnSameLine; + } + + addComment(commentBegin, current_, placement); + } + return true; +} + +static std::string normalizeEOL(Reader::Location begin, Reader::Location end) { + std::string normalized; + normalized.reserve(end - begin); + Reader::Location current = begin; + while (current != end) { + char c = *current++; + if (c == '\r') { + if (current != end && *current == '\n') + // convert dos EOL + ++current; + // convert Mac EOL + normalized += '\n'; + } else { + normalized += c; + } + } + return normalized; +} + +void +Reader::addComment(Location begin, Location end, CommentPlacement placement) { + assert(collectComments_); + const std::string& normalized = normalizeEOL(begin, end); + if (placement == commentAfterOnSameLine) { + assert(lastValue_ != 0); + lastValue_->setComment(normalized, placement); + } else { + commentsBefore_ += normalized; + } +} + +bool Reader::readCStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '*' && *current_ == '/') + break; + } + return getNextChar() == '/'; +} + +bool Reader::readCppStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '\n') + break; + if (c == '\r') { + // Consume DOS EOL. It will be normalized in addComment. + if (current_ != end_ && *current_ == '\n') + getNextChar(); + // Break on Moc OS 9 EOL. + break; + } + } + return true; +} + +void Reader::readNumber() { + const char *p = current_; + char c = '0'; // stopgap for already consumed character + // integral part + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + // fractional part + if (c == '.') { + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } + // exponential part + if (c == 'e' || c == 'E') { + c = (current_ = p) < end_ ? *p++ : 0; + if (c == '+' || c == '-') + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } +} + +bool Reader::readString() { + Char c = 0; + while (current_ != end_) { + c = getNextChar(); + if (c == '\\') + getNextChar(); + else if (c == '"') + break; + } + return c == '"'; +} + +bool Reader::readObject(Token& /*tokenStart*/) { + Token tokenName; + std::string name; + Value init(objectValue); + currentValue().swapPayload(init); + while (readToken(tokenName)) { + bool initialTokenOk = true; + while (tokenName.type_ == tokenComment && initialTokenOk) + initialTokenOk = readToken(tokenName); + if (!initialTokenOk) + break; + if (tokenName.type_ == tokenObjectEnd && name.empty()) // empty object + return true; + name = ""; + if (tokenName.type_ == tokenString) { + if (!decodeString(tokenName, name)) + return recoverFromError(tokenObjectEnd); + } else { + break; + } + + Token colon; + if (!readToken(colon) || colon.type_ != tokenMemberSeparator) { + return addErrorAndRecover( + "Missing ':' after object member name", colon, tokenObjectEnd); + } + Value& value = currentValue()[name]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenObjectEnd); + + Token comma; + if (!readToken(comma) || + (comma.type_ != tokenObjectEnd && comma.type_ != tokenArraySeparator && + comma.type_ != tokenComment)) { + return addErrorAndRecover( + "Missing ',' or '}' in object declaration", comma, tokenObjectEnd); + } + bool finalizeTokenOk = true; + while (comma.type_ == tokenComment && finalizeTokenOk) + finalizeTokenOk = readToken(comma); + if (comma.type_ == tokenObjectEnd) + return true; + } + return addErrorAndRecover( + "Missing '}' or object member name", tokenName, tokenObjectEnd); +} + +bool Reader::readArray(Token& /*tokenStart*/) { + Value init(arrayValue); + currentValue().swapPayload(init); + skipSpaces(); + if (*current_ == ']') // empty array + { + Token endArray; + readToken(endArray); + return true; + } + int index = 0; + for (;;) { + Value& value = currentValue()[index++]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenArrayEnd); + + Token token; + // Accept Comment after last item in the array. + ok = readToken(token); + while (token.type_ == tokenComment && ok) { + ok = readToken(token); + } + bool badTokenType = + (token.type_ != tokenArraySeparator && token.type_ != tokenArrayEnd); + if (!ok || badTokenType) { + return addErrorAndRecover( + "Missing ',' or ']' in array declaration", token, tokenArrayEnd); + } + if (token.type_ == tokenArrayEnd) + break; + } + return true; +} + +bool Reader::decodeNumber(Token& token) { + Value decoded; + if (!decodeNumber(token, decoded)) + return false; + currentValue().swapPayload(decoded); + return true; +} + +bool Reader::decodeNumber(Token& token, Value& decoded) { + // Attempts to parse the number as an integer. If the number is + // larger than the maximum supported value of an integer then + // we decode the number as a double. + Location current = token.start_; + bool isNegative = *current == '-'; + if (isNegative) + ++current; + // TODO: Help the compiler do the div and mod at compile time or get rid of them. + Value::LargestUInt maxIntegerValue = + isNegative ? Value::LargestUInt(Value::maxLargestInt) + 1 + : Value::maxLargestUInt; + Value::LargestUInt threshold = maxIntegerValue / 10; + Value::LargestUInt value = 0; + while (current < token.end_) { + Char c = *current++; + if (c < '0' || c > '9') + return decodeDouble(token, decoded); + Value::UInt digit(c - '0'); + if (value >= threshold) { + // We've hit or exceeded the max value divided by 10 (rounded down). If + // a) we've only just touched the limit, b) this is the last digit, and + // c) it's small enough to fit in that rounding delta, we're okay. + // Otherwise treat this number as a double to avoid overflow. + if (value > threshold || current != token.end_ || + digit > maxIntegerValue % 10) { + return decodeDouble(token, decoded); + } + } + value = value * 10 + digit; + } + if (isNegative && value == maxIntegerValue) + decoded = Value::minLargestInt; + else if (isNegative) + decoded = -Value::LargestInt(value); + else if (value <= Value::LargestUInt(Value::maxInt)) + decoded = Value::LargestInt(value); + else + decoded = value; + return true; +} + +bool Reader::decodeDouble(Token& token) { + Value decoded; + if (!decodeDouble(token, decoded)) + return false; + currentValue().swapPayload(decoded); + return true; +} + +bool Reader::decodeDouble(Token& token, Value& decoded) { + double value = 0; + std::string buffer(token.start_, token.end_); + std::istringstream is(buffer); + if (!(is >> value)) + return addError("'" + std::string(token.start_, token.end_) + + "' is not a number.", + token); + decoded = value; + return true; +} + +bool Reader::decodeString(Token& token) { + std::string decoded_string; + if (!decodeString(token, decoded_string)) + return false; + Value decoded(decoded_string); + currentValue().swapPayload(decoded); + return true; +} + +bool Reader::decodeString(Token& token, std::string& decoded) { + decoded.reserve(token.end_ - token.start_ - 2); + Location current = token.start_ + 1; // skip '"' + Location end = token.end_ - 1; // do not include '"' + while (current != end) { + Char c = *current++; + if (c == '"') + break; + else if (c == '\\') { + if (current == end) + return addError("Empty escape sequence in string", token, current); + Char escape = *current++; + switch (escape) { + case '"': + decoded += '"'; + break; + case '/': + decoded += '/'; + break; + case '\\': + decoded += '\\'; + break; + case 'b': + decoded += '\b'; + break; + case 'f': + decoded += '\f'; + break; + case 'n': + decoded += '\n'; + break; + case 'r': + decoded += '\r'; + break; + case 't': + decoded += '\t'; + break; + case 'u': { + unsigned int unicode; + if (!decodeUnicodeCodePoint(token, current, end, unicode)) + return false; + decoded += codePointToUTF8(unicode); + } break; + default: + return addError("Bad escape sequence in string", token, current); + } + } else { + decoded += c; + } + } + return true; +} + +bool Reader::decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + + if (!decodeUnicodeEscapeSequence(token, current, end, unicode)) + return false; + if (unicode >= 0xD800 && unicode <= 0xDBFF) { + // surrogate pairs + if (end - current < 6) + return addError( + "additional six characters expected to parse unicode surrogate pair.", + token, + current); + unsigned int surrogatePair; + if (*(current++) == '\\' && *(current++) == 'u') { + if (decodeUnicodeEscapeSequence(token, current, end, surrogatePair)) { + unicode = 0x10000 + ((unicode & 0x3FF) << 10) + (surrogatePair & 0x3FF); + } else + return false; + } else + return addError("expecting another \\u token to begin the second half of " + "a unicode surrogate pair", + token, + current); + } + return true; +} + +bool Reader::decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + if (end - current < 4) + return addError( + "Bad unicode escape sequence in string: four digits expected.", + token, + current); + unicode = 0; + for (int index = 0; index < 4; ++index) { + Char c = *current++; + unicode *= 16; + if (c >= '0' && c <= '9') + unicode += c - '0'; + else if (c >= 'a' && c <= 'f') + unicode += c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + unicode += c - 'A' + 10; + else + return addError( + "Bad unicode escape sequence in string: hexadecimal digit expected.", + token, + current); + } + return true; +} + +bool +Reader::addError(const std::string& message, Token& token, Location extra) { + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = extra; + errors_.push_back(info); + return false; +} + +bool Reader::recoverFromError(TokenType skipUntilToken) { + int errorCount = int(errors_.size()); + Token skip; + for (;;) { + if (!readToken(skip)) + errors_.resize(errorCount); // discard errors caused by recovery + if (skip.type_ == skipUntilToken || skip.type_ == tokenEndOfStream) + break; + } + errors_.resize(errorCount); + return false; +} + +bool Reader::addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken) { + addError(message, token); + return recoverFromError(skipUntilToken); +} + +Value& Reader::currentValue() { return *(nodes_.top()); } + +Reader::Char Reader::getNextChar() { + if (current_ == end_) + return 0; + return *current_++; +} + +void Reader::getLocationLineAndColumn(Location location, + int& line, + int& column) const { + Location current = begin_; + Location lastLineStart = current; + line = 0; + while (current < location && current != end_) { + Char c = *current++; + if (c == '\r') { + if (*current == '\n') + ++current; + lastLineStart = current; + ++line; + } else if (c == '\n') { + lastLineStart = current; + ++line; + } + } + // column & line start at 1 + column = int(location - lastLineStart) + 1; + ++line; +} + +std::string Reader::getLocationLineAndColumn(Location location) const { + int line, column; + getLocationLineAndColumn(location, line, column); + char buffer[18 + 16 + 16 + 1]; + snprintf(buffer, sizeof(buffer), "Line %d, Column %d", line, column); + return buffer; +} + +// Deprecated. Preserved for backward compatibility +std::string Reader::getFormatedErrorMessages() const { + return getFormattedErrorMessages(); +} + +std::string Reader::getFormattedErrorMessages() const { + std::string formattedMessage; + for (Errors::const_iterator itError = errors_.begin(); + itError != errors_.end(); + ++itError) { + const ErrorInfo& error = *itError; + formattedMessage += + "* " + getLocationLineAndColumn(error.token_.start_) + "\n"; + formattedMessage += " " + error.message_ + "\n"; + if (error.extra_) + formattedMessage += + "See " + getLocationLineAndColumn(error.extra_) + " for detail.\n"; + } + return formattedMessage; +} + +// Reader +///////////////////////// + +// exact copy of Features +class OurFeatures { +public: + static OurFeatures all(); + OurFeatures(); + bool allowComments_; + bool strictRoot_; + bool allowDroppedNullPlaceholders_; + bool allowNumericKeys_; + bool allowSingleQuotes_; + bool failIfExtra_; + bool rejectDupKeys_; + bool allowSpecialFloats_; + int stackLimit_; +}; // OurFeatures + +// exact copy of Implementation of class Features +// //////////////////////////////// + +OurFeatures::OurFeatures() + : allowComments_(true), strictRoot_(false) + , allowDroppedNullPlaceholders_(false), allowNumericKeys_(false) + , allowSingleQuotes_(false) + , failIfExtra_(false) + , allowSpecialFloats_(false) +{ +} + +OurFeatures OurFeatures::all() { return OurFeatures(); } + +// Implementation of class Reader +// //////////////////////////////// + +// exact copy of Reader, renamed to OurReader +class OurReader { +public: + typedef char Char; + typedef const Char* Location; + struct StructuredError { + size_t offset_start; + size_t offset_limit; + std::string message; + }; + + OurReader(OurFeatures const& features); + bool parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments = true); + std::string getFormattedErrorMessages() const; + +private: + OurReader(OurReader const&); // no impl + void operator=(OurReader const&); // no impl + + enum TokenType { + tokenEndOfStream = 0, + tokenObjectBegin, + tokenObjectEnd, + tokenArrayBegin, + tokenArrayEnd, + tokenString, + tokenNumber, + tokenTrue, + tokenFalse, + tokenNull, + tokenNaN, + tokenPosInf, + tokenNegInf, + tokenArraySeparator, + tokenMemberSeparator, + tokenComment, + tokenError + }; + + class Token { + public: + TokenType type_; + Location start_; + Location end_; + }; + + class ErrorInfo { + public: + Token token_; + std::string message_; + Location extra_; + }; + + typedef std::deque Errors; + + bool readToken(Token& token); + void skipSpaces(); + bool match(Location pattern, int patternLength); + bool readComment(); + bool readCStyleComment(); + bool readCppStyleComment(); + bool readString(); + bool readStringSingleQuote(); + bool readNumber(bool checkInf); + bool readValue(); + bool readObject(Token& token); + bool readArray(Token& token); + bool decodeNumber(Token& token); + bool decodeNumber(Token& token, Value& decoded); + bool decodeString(Token& token); + bool decodeString(Token& token, std::string& decoded); + bool decodeDouble(Token& token); + bool decodeDouble(Token& token, Value& decoded); + bool decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool addError(const std::string& message, Token& token, Location extra = 0); + bool recoverFromError(TokenType skipUntilToken); + bool addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken); + void skipUntilSpace(); + Value& currentValue(); + Char getNextChar(); + void + getLocationLineAndColumn(Location location, int& line, int& column) const; + std::string getLocationLineAndColumn(Location location) const; + void addComment(Location begin, Location end, CommentPlacement placement); + void skipCommentTokens(Token& token); + + typedef std::stack Nodes; + Nodes nodes_; + Errors errors_; + std::string document_; + Location begin_; + Location end_; + Location current_; + Location lastValueEnd_; + Value* lastValue_; + std::string commentsBefore_; + int stackDepth_; + + OurFeatures const features_; + bool collectComments_; +}; // OurReader + +// complete copy of Read impl, for OurReader + +OurReader::OurReader(OurFeatures const& features) + : errors_(), document_(), begin_(), end_(), current_(), lastValueEnd_(), + lastValue_(), commentsBefore_(), features_(features), collectComments_() { +} + +bool OurReader::parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments) { + if (!features_.allowComments_) { + collectComments = false; + } + + begin_ = beginDoc; + end_ = endDoc; + collectComments_ = collectComments; + current_ = begin_; + lastValueEnd_ = 0; + lastValue_ = 0; + commentsBefore_ = ""; + errors_.clear(); + while (!nodes_.empty()) + nodes_.pop(); + nodes_.push(&root); + + stackDepth_ = 0; + bool successful = readValue(); + Token token; + skipCommentTokens(token); + if (features_.failIfExtra_) { + if (token.type_ != tokenError && token.type_ != tokenEndOfStream) { + addError("Extra non-whitespace after JSON value.", token); + return false; + } + } + if (collectComments_ && !commentsBefore_.empty()) + root.setComment(commentsBefore_, commentAfter); + if (features_.strictRoot_) { + if (!root.isArray() && !root.isObject()) { + // Set error location to start of doc, ideally should be first token found + // in doc + token.type_ = tokenError; + token.start_ = beginDoc; + token.end_ = endDoc; + addError( + "A valid JSON document must be either an array or an object value.", + token); + return false; + } + } + return successful; +} + +bool OurReader::readValue() { + if (stackDepth_ >= features_.stackLimit_) throwRuntimeError("Exceeded stackLimit in readValue()."); + ++stackDepth_; + Token token; + skipCommentTokens(token); + bool successful = true; + + if (collectComments_ && !commentsBefore_.empty()) { + currentValue().setComment(commentsBefore_, commentBefore); + commentsBefore_ = ""; + } + + switch (token.type_) { + case tokenObjectBegin: + successful = readObject(token); + break; + case tokenArrayBegin: + successful = readArray(token); + break; + case tokenNumber: + successful = decodeNumber(token); + break; + case tokenString: + successful = decodeString(token); + break; + case tokenTrue: + { + Value v(true); + currentValue().swapPayload(v); + } + break; + case tokenFalse: + { + Value v(false); + currentValue().swapPayload(v); + } + break; + case tokenNull: + { + Value v; + currentValue().swapPayload(v); + } + break; + case tokenNaN: + { + Value v(std::numeric_limits::quiet_NaN()); + currentValue().swapPayload(v); + } + break; + case tokenPosInf: + { + Value v(std::numeric_limits::infinity()); + currentValue().swapPayload(v); + } + break; + case tokenNegInf: + { + Value v(-std::numeric_limits::infinity()); + currentValue().swapPayload(v); + } + break; + case tokenArraySeparator: + case tokenObjectEnd: + case tokenArrayEnd: + if (features_.allowDroppedNullPlaceholders_) { + // "Un-read" the current token and mark the current value as a null + // token. + current_--; + Value v; + currentValue().swapPayload(v); + break; + } // else, fall through ... + default: + return addError("Syntax error: value, object or array expected.", token); + } + + if (collectComments_) { + lastValueEnd_ = current_; + lastValue_ = ¤tValue(); + } + + --stackDepth_; + return successful; +} + +void OurReader::skipCommentTokens(Token& token) { + if (features_.allowComments_) { + do { + readToken(token); + } while (token.type_ == tokenComment); + } else { + readToken(token); + } +} + +bool OurReader::readToken(Token& token) { + skipSpaces(); + token.start_ = current_; + Char c = getNextChar(); + bool ok = true; + switch (c) { + case '{': + token.type_ = tokenObjectBegin; + break; + case '}': + token.type_ = tokenObjectEnd; + break; + case '[': + token.type_ = tokenArrayBegin; + break; + case ']': + token.type_ = tokenArrayEnd; + break; + case '"': + token.type_ = tokenString; + ok = readString(); + break; + case '\'': + if (features_.allowSingleQuotes_) { + token.type_ = tokenString; + ok = readStringSingleQuote(); + break; + } // else continue + case '/': + token.type_ = tokenComment; + ok = readComment(); + break; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + token.type_ = tokenNumber; + readNumber(false); + break; + case '-': + if (readNumber(true)) { + token.type_ = tokenNumber; + } else { + token.type_ = tokenNegInf; + ok = features_.allowSpecialFloats_ && match("nfinity", 7); + } + break; + case 't': + token.type_ = tokenTrue; + ok = match("rue", 3); + break; + case 'f': + token.type_ = tokenFalse; + ok = match("alse", 4); + break; + case 'n': + token.type_ = tokenNull; + ok = match("ull", 3); + break; + case 'N': + if (features_.allowSpecialFloats_) { + token.type_ = tokenNaN; + ok = match("aN", 2); + } else { + ok = false; + } + break; + case 'I': + if (features_.allowSpecialFloats_) { + token.type_ = tokenPosInf; + ok = match("nfinity", 7); + } else { + ok = false; + } + break; + case ',': + token.type_ = tokenArraySeparator; + break; + case ':': + token.type_ = tokenMemberSeparator; + break; + case 0: + token.type_ = tokenEndOfStream; + break; + default: + ok = false; + break; + } + if (!ok) + token.type_ = tokenError; + token.end_ = current_; + return true; +} + +void OurReader::skipSpaces() { + while (current_ != end_) { + Char c = *current_; + if (c == ' ' || c == '\t' || c == '\r' || c == '\n') + ++current_; + else + break; + } +} + +bool OurReader::match(Location pattern, int patternLength) { + if (end_ - current_ < patternLength) + return false; + int index = patternLength; + while (index--) + if (current_[index] != pattern[index]) + return false; + current_ += patternLength; + return true; +} + +bool OurReader::readComment() { + Location commentBegin = current_ - 1; + Char c = getNextChar(); + bool successful = false; + if (c == '*') + successful = readCStyleComment(); + else if (c == '/') + successful = readCppStyleComment(); + if (!successful) + return false; + + if (collectComments_) { + CommentPlacement placement = commentBefore; + if (lastValueEnd_ && !containsNewLine(lastValueEnd_, commentBegin)) { + if (c != '*' || !containsNewLine(commentBegin, current_)) + placement = commentAfterOnSameLine; + } + + addComment(commentBegin, current_, placement); + } + return true; +} + +void +OurReader::addComment(Location begin, Location end, CommentPlacement placement) { + assert(collectComments_); + const std::string& normalized = normalizeEOL(begin, end); + if (placement == commentAfterOnSameLine) { + assert(lastValue_ != 0); + lastValue_->setComment(normalized, placement); + } else { + commentsBefore_ += normalized; + } +} + +bool OurReader::readCStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '*' && *current_ == '/') + break; + } + return getNextChar() == '/'; +} + +bool OurReader::readCppStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '\n') + break; + if (c == '\r') { + // Consume DOS EOL. It will be normalized in addComment. + if (current_ != end_ && *current_ == '\n') + getNextChar(); + // Break on Moc OS 9 EOL. + break; + } + } + return true; +} + +bool OurReader::readNumber(bool checkInf) { + const char *p = current_; + if (checkInf && p != end_ && *p == 'I') { + current_ = ++p; + return false; + } + char c = '0'; // stopgap for already consumed character + // integral part + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + // fractional part + if (c == '.') { + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } + // exponential part + if (c == 'e' || c == 'E') { + c = (current_ = p) < end_ ? *p++ : 0; + if (c == '+' || c == '-') + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } + return true; +} +bool OurReader::readString() { + Char c = 0; + while (current_ != end_) { + c = getNextChar(); + if (c == '\\') + getNextChar(); + else if (c == '"') + break; + } + return c == '"'; +} + + +bool OurReader::readStringSingleQuote() { + Char c = 0; + while (current_ != end_) { + c = getNextChar(); + if (c == '\\') + getNextChar(); + else if (c == '\'') + break; + } + return c == '\''; +} + +bool OurReader::readObject(Token& /*tokenStart*/) { + Token tokenName; + std::string name; + Value init(objectValue); + currentValue().swapPayload(init); + while (readToken(tokenName)) { + bool initialTokenOk = true; + while (tokenName.type_ == tokenComment && initialTokenOk) + initialTokenOk = readToken(tokenName); + if (!initialTokenOk) + break; + if (tokenName.type_ == tokenObjectEnd && name.empty()) // empty object + return true; + name = ""; + if (tokenName.type_ == tokenString) { + if (!decodeString(tokenName, name)) + return recoverFromError(tokenObjectEnd); + } else if (tokenName.type_ == tokenNumber && features_.allowNumericKeys_) { + Value numberName; + if (!decodeNumber(tokenName, numberName)) + return recoverFromError(tokenObjectEnd); + name = numberName.asString(); + } else { + break; + } + + Token colon; + if (!readToken(colon) || colon.type_ != tokenMemberSeparator) { + return addErrorAndRecover( + "Missing ':' after object member name", colon, tokenObjectEnd); + } + if (name.length() >= (1U<<30)) throwRuntimeError("keylength >= 2^30"); + if (features_.rejectDupKeys_ && currentValue().isMember(name)) { + std::string msg = "Duplicate key: '" + name + "'"; + return addErrorAndRecover( + msg, tokenName, tokenObjectEnd); + } + Value& value = currentValue()[name]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenObjectEnd); + + Token comma; + if (!readToken(comma) || + (comma.type_ != tokenObjectEnd && comma.type_ != tokenArraySeparator && + comma.type_ != tokenComment)) { + return addErrorAndRecover( + "Missing ',' or '}' in object declaration", comma, tokenObjectEnd); + } + bool finalizeTokenOk = true; + while (comma.type_ == tokenComment && finalizeTokenOk) + finalizeTokenOk = readToken(comma); + if (comma.type_ == tokenObjectEnd) + return true; + } + return addErrorAndRecover( + "Missing '}' or object member name", tokenName, tokenObjectEnd); +} + +bool OurReader::readArray(Token& /*tokenStart*/) { + Value init(arrayValue); + currentValue().swapPayload(init); + skipSpaces(); + if (*current_ == ']') // empty array + { + Token endArray; + readToken(endArray); + return true; + } + int index = 0; + for (;;) { + Value& value = currentValue()[index++]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenArrayEnd); + + Token token; + // Accept Comment after last item in the array. + ok = readToken(token); + while (token.type_ == tokenComment && ok) { + ok = readToken(token); + } + bool badTokenType = + (token.type_ != tokenArraySeparator && token.type_ != tokenArrayEnd); + if (!ok || badTokenType) { + return addErrorAndRecover( + "Missing ',' or ']' in array declaration", token, tokenArrayEnd); + } + if (token.type_ == tokenArrayEnd) + break; + } + return true; +} + +bool OurReader::decodeNumber(Token& token) { + Value decoded; + if (!decodeNumber(token, decoded)) + return false; + currentValue().swapPayload(decoded); + return true; +} + +bool OurReader::decodeNumber(Token& token, Value& decoded) { + // Attempts to parse the number as an integer. If the number is + // larger than the maximum supported value of an integer then + // we decode the number as a double. + Location current = token.start_; + bool isNegative = *current == '-'; + if (isNegative) + ++current; + // TODO: Help the compiler do the div and mod at compile time or get rid of them. + Value::LargestUInt maxIntegerValue = + isNegative ? Value::LargestUInt(-Value::minLargestInt) + : Value::maxLargestUInt; + Value::LargestUInt threshold = maxIntegerValue / 10; + Value::LargestUInt value = 0; + while (current < token.end_) { + Char c = *current++; + if (c < '0' || c > '9') + return decodeDouble(token, decoded); + Value::UInt digit(c - '0'); + if (value >= threshold) { + // We've hit or exceeded the max value divided by 10 (rounded down). If + // a) we've only just touched the limit, b) this is the last digit, and + // c) it's small enough to fit in that rounding delta, we're okay. + // Otherwise treat this number as a double to avoid overflow. + if (value > threshold || current != token.end_ || + digit > maxIntegerValue % 10) { + return decodeDouble(token, decoded); + } + } + value = value * 10 + digit; + } + if (isNegative) + decoded = -Value::LargestInt(value); + else if (value <= Value::LargestUInt(Value::maxInt)) + decoded = Value::LargestInt(value); + else + decoded = value; + return true; +} + +bool OurReader::decodeDouble(Token& token) { + Value decoded; + if (!decodeDouble(token, decoded)) + return false; + currentValue().swapPayload(decoded); + return true; +} + +bool OurReader::decodeDouble(Token& token, Value& decoded) { + double value = 0; + std::string buffer( token.start_, token.end_ ); + std::istringstream is(buffer); + if (!(is >> value)) + return addError("'" + std::string(token.start_, token.end_) + + "' is not a number.", + token); + decoded = value; + return true; +} + +bool OurReader::decodeString(Token& token) { + std::string decoded_string; + if (!decodeString(token, decoded_string)) + return false; + Value decoded(decoded_string); + currentValue().swapPayload(decoded); + return true; +} + +bool OurReader::decodeString(Token& token, std::string& decoded) { + decoded.reserve(token.end_ - token.start_ - 2); + Location current = token.start_ + 1; // skip '"' + Location end = token.end_ - 1; // do not include '"' + while (current != end) { + Char c = *current++; + if (c == '"') + break; + else if (c == '\\') { + if (current == end) + return addError("Empty escape sequence in string", token, current); + Char escape = *current++; + switch (escape) { + case '"': + decoded += '"'; + break; + case '/': + decoded += '/'; + break; + case '\\': + decoded += '\\'; + break; + case 'b': + decoded += '\b'; + break; + case 'f': + decoded += '\f'; + break; + case 'n': + decoded += '\n'; + break; + case 'r': + decoded += '\r'; + break; + case 't': + decoded += '\t'; + break; + case 'u': { + unsigned int unicode; + if (!decodeUnicodeCodePoint(token, current, end, unicode)) + return false; + decoded += codePointToUTF8(unicode); + } break; + default: + return addError("Bad escape sequence in string", token, current); + } + } else { + decoded += c; + } + } + return true; +} + +bool OurReader::decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + + if (!decodeUnicodeEscapeSequence(token, current, end, unicode)) + return false; + if (unicode >= 0xD800 && unicode <= 0xDBFF) { + // surrogate pairs + if (end - current < 6) + return addError( + "additional six characters expected to parse unicode surrogate pair.", + token, + current); + unsigned int surrogatePair; + if (*(current++) == '\\' && *(current++) == 'u') { + if (decodeUnicodeEscapeSequence(token, current, end, surrogatePair)) { + unicode = 0x10000 + ((unicode & 0x3FF) << 10) + (surrogatePair & 0x3FF); + } else + return false; + } else + return addError("expecting another \\u token to begin the second half of " + "a unicode surrogate pair", + token, + current); + } + return true; +} + +bool OurReader::decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + if (end - current < 4) + return addError( + "Bad unicode escape sequence in string: four digits expected.", + token, + current); + unicode = 0; + for (int index = 0; index < 4; ++index) { + Char c = *current++; + unicode *= 16; + if (c >= '0' && c <= '9') + unicode += c - '0'; + else if (c >= 'a' && c <= 'f') + unicode += c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + unicode += c - 'A' + 10; + else + return addError( + "Bad unicode escape sequence in string: hexadecimal digit expected.", + token, + current); + } + return true; +} + +bool +OurReader::addError(const std::string& message, Token& token, Location extra) { + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = extra; + errors_.push_back(info); + return false; +} + +bool OurReader::recoverFromError(TokenType skipUntilToken) { + int errorCount = int(errors_.size()); + Token skip; + for (;;) { + if (!readToken(skip)) + errors_.resize(errorCount); // discard errors caused by recovery + if (skip.type_ == skipUntilToken || skip.type_ == tokenEndOfStream) + break; + } + errors_.resize(errorCount); + return false; +} + +bool OurReader::addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken) { + addError(message, token); + return recoverFromError(skipUntilToken); +} + +Value& OurReader::currentValue() { return *(nodes_.top()); } + +OurReader::Char OurReader::getNextChar() { + if (current_ == end_) + return 0; + return *current_++; +} + +void OurReader::getLocationLineAndColumn(Location location, + int& line, + int& column) const { + Location current = begin_; + Location lastLineStart = current; + line = 0; + while (current < location && current != end_) { + Char c = *current++; + if (c == '\r') { + if (*current == '\n') + ++current; + lastLineStart = current; + ++line; + } else if (c == '\n') { + lastLineStart = current; + ++line; + } + } + // column & line start at 1 + column = int(location - lastLineStart) + 1; + ++line; +} + +std::string OurReader::getLocationLineAndColumn(Location location) const { + int line, column; + getLocationLineAndColumn(location, line, column); + char buffer[18 + 16 + 16 + 1]; + snprintf(buffer, sizeof(buffer), "Line %d, Column %d", line, column); + return buffer; +} + +std::string OurReader::getFormattedErrorMessages() const { + std::string formattedMessage; + for (Errors::const_iterator itError = errors_.begin(); + itError != errors_.end(); + ++itError) { + const ErrorInfo& error = *itError; + formattedMessage += + "* " + getLocationLineAndColumn(error.token_.start_) + "\n"; + formattedMessage += " " + error.message_ + "\n"; + if (error.extra_) + formattedMessage += + "See " + getLocationLineAndColumn(error.extra_) + " for detail.\n"; + } + return formattedMessage; +} + + +class OurCharReader : public CharReader { + bool const collectComments_; + OurReader reader_; +public: + OurCharReader( + bool collectComments, + OurFeatures const& features) + : collectComments_(collectComments) + , reader_(features) + {} + virtual bool parse( + char const* beginDoc, char const* endDoc, + Value* root, std::string* errs) { + bool ok = reader_.parse(beginDoc, endDoc, *root, collectComments_); + if (errs) { + *errs = reader_.getFormattedErrorMessages(); + } + return ok; + } +}; + +CharReaderBuilder::CharReaderBuilder() +{ + setDefaults(&settings_); +} +CharReaderBuilder::~CharReaderBuilder() +{} +CharReader* CharReaderBuilder::newCharReader() const +{ + bool collectComments = settings_["collectComments"].asBool(); + OurFeatures features = OurFeatures::all(); + features.allowComments_ = settings_["allowComments"].asBool(); + features.strictRoot_ = settings_["strictRoot"].asBool(); + features.allowDroppedNullPlaceholders_ = settings_["allowDroppedNullPlaceholders"].asBool(); + features.allowNumericKeys_ = settings_["allowNumericKeys"].asBool(); + features.allowSingleQuotes_ = settings_["allowSingleQuotes"].asBool(); + features.stackLimit_ = settings_["stackLimit"].asInt(); + features.failIfExtra_ = settings_["failIfExtra"].asBool(); + features.rejectDupKeys_ = settings_["rejectDupKeys"].asBool(); + features.allowSpecialFloats_ = settings_["allowSpecialFloats"].asBool(); + return new OurCharReader(collectComments, features); +} +static void getValidReaderKeys(std::set* valid_keys) +{ + valid_keys->clear(); + valid_keys->insert("collectComments"); + valid_keys->insert("allowComments"); + valid_keys->insert("strictRoot"); + valid_keys->insert("allowDroppedNullPlaceholders"); + valid_keys->insert("allowNumericKeys"); + valid_keys->insert("allowSingleQuotes"); + valid_keys->insert("stackLimit"); + valid_keys->insert("failIfExtra"); + valid_keys->insert("rejectDupKeys"); + valid_keys->insert("allowSpecialFloats"); +} +bool CharReaderBuilder::validate(Json::Value* invalid) const +{ + Json::Value my_invalid; + if (!invalid) invalid = &my_invalid; // so we do not need to test for NULL + Json::Value& inv = *invalid; + std::set valid_keys; + getValidReaderKeys(&valid_keys); + Value::Members keys = settings_.getMemberNames(); + size_t n = keys.size(); + for (size_t i = 0; i < n; ++i) { + std::string const& key = keys[i]; + if (valid_keys.find(key) == valid_keys.end()) { + inv[key] = settings_[key]; + } + } + return 0u == inv.size(); +} +Value& CharReaderBuilder::operator[](std::string key) +{ + return settings_[key]; +} +// static +void CharReaderBuilder::strictMode(Json::Value* settings) +{ +//! [CharReaderBuilderStrictMode] + (*settings)["allowComments"] = false; + (*settings)["strictRoot"] = true; + (*settings)["allowDroppedNullPlaceholders"] = false; + (*settings)["allowNumericKeys"] = false; + (*settings)["allowSingleQuotes"] = false; + (*settings)["failIfExtra"] = true; + (*settings)["rejectDupKeys"] = true; + (*settings)["allowSpecialFloats"] = false; +//! [CharReaderBuilderStrictMode] +} +// static +void CharReaderBuilder::setDefaults(Json::Value* settings) +{ +//! [CharReaderBuilderDefaults] + (*settings)["collectComments"] = true; + (*settings)["allowComments"] = true; + (*settings)["strictRoot"] = false; + (*settings)["allowDroppedNullPlaceholders"] = false; + (*settings)["allowNumericKeys"] = false; + (*settings)["allowSingleQuotes"] = false; + (*settings)["stackLimit"] = 1000; + (*settings)["failIfExtra"] = false; + (*settings)["rejectDupKeys"] = false; + (*settings)["allowSpecialFloats"] = false; +//! [CharReaderBuilderDefaults] +} + +////////////////////////////////// +// global functions + +bool parseFromStream( + CharReader::Factory const& fact, std::istream& sin, + Value* root, std::string* errs) +{ + std::ostringstream ssin; + ssin << sin.rdbuf(); + std::string doc = ssin.str(); + char const* begin = doc.data(); + char const* end = begin + doc.size(); + // Note that we do not actually need a null-terminator. + CharReaderPtr const reader(fact.newCharReader()); + return reader->parse(begin, end, root, errs); +} + +std::istream& operator>>(std::istream& sin, Value& root) { + CharReaderBuilder b; + std::string errs; + bool ok = parseFromStream(b, sin, &root, &errs); + if (!ok) { + fprintf(stderr, + "Error from reader: %s", + errs.c_str()); + + throwRuntimeError("reader error"); + } + return sin; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_reader.cpp +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_valueiterator.inl +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +// included by json_value.cpp + +namespace Json { + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class ValueIteratorBase +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +ValueIteratorBase::ValueIteratorBase() + : current_(), isNull_(true) { +} + +ValueIteratorBase::ValueIteratorBase( + const Value::ObjectValues::iterator& current) + : current_(current), isNull_(false) {} + +Value& ValueIteratorBase::deref() const { + return current_->second; +} + +void ValueIteratorBase::increment() { + ++current_; +} + +void ValueIteratorBase::decrement() { + --current_; +} + +ValueIteratorBase::difference_type +ValueIteratorBase::computeDistance(const SelfType& other) const { +#ifdef JSON_USE_CPPTL_SMALLMAP + return other.current_ - current_; +#else + // Iterator for null value are initialized using the default + // constructor, which initialize current_ to the default + // std::map::iterator. As begin() and end() are two instance + // of the default std::map::iterator, they can not be compared. + // To allow this, we handle this comparison specifically. + if (isNull_ && other.isNull_) { + return 0; + } + + // Usage of std::distance is not portable (does not compile with Sun Studio 12 + // RogueWave STL, + // which is the one used by default). + // Using a portable hand-made version for non random iterator instead: + // return difference_type( std::distance( current_, other.current_ ) ); + difference_type myDistance = 0; + for (Value::ObjectValues::iterator it = current_; it != other.current_; + ++it) { + ++myDistance; + } + return myDistance; +#endif +} + +bool ValueIteratorBase::isEqual(const SelfType& other) const { + if (isNull_) { + return other.isNull_; + } + return current_ == other.current_; +} + +void ValueIteratorBase::copy(const SelfType& other) { + current_ = other.current_; + isNull_ = other.isNull_; +} + +Value ValueIteratorBase::key() const { + const Value::CZString czstring = (*current_).first; + if (czstring.data()) { + if (czstring.isStaticString()) + return Value(StaticString(czstring.data())); + return Value(czstring.data(), czstring.data() + czstring.length()); + } + return Value(czstring.index()); +} + +UInt ValueIteratorBase::index() const { + const Value::CZString czstring = (*current_).first; + if (!czstring.data()) + return czstring.index(); + return Value::UInt(-1); +} + +std::string ValueIteratorBase::name() const { + char const* keey; + char const* end; + keey = memberName(&end); + if (!keey) return std::string(); + return std::string(keey, end); +} + +char const* ValueIteratorBase::memberName() const { + const char* cname = (*current_).first.data(); + return cname ? cname : ""; +} + +char const* ValueIteratorBase::memberName(char const** end) const { + const char* cname = (*current_).first.data(); + if (!cname) { + *end = NULL; + return NULL; + } + *end = cname + (*current_).first.length(); + return cname; +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class ValueConstIterator +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +ValueConstIterator::ValueConstIterator() {} + +ValueConstIterator::ValueConstIterator( + const Value::ObjectValues::iterator& current) + : ValueIteratorBase(current) {} + +ValueConstIterator& ValueConstIterator:: +operator=(const ValueIteratorBase& other) { + copy(other); + return *this; +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class ValueIterator +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +ValueIterator::ValueIterator() {} + +ValueIterator::ValueIterator(const Value::ObjectValues::iterator& current) + : ValueIteratorBase(current) {} + +ValueIterator::ValueIterator(const ValueConstIterator& other) + : ValueIteratorBase(other) {} + +ValueIterator::ValueIterator(const ValueIterator& other) + : ValueIteratorBase(other) {} + +ValueIterator& ValueIterator::operator=(const SelfType& other) { + copy(other); + return *this; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_valueiterator.inl +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_value.cpp +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2011 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include +#endif // if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include +#include +#include +#ifdef JSON_USE_CPPTL +#include +#endif +#include // size_t +#include // min() +#if defined(__BORLANDC__) +#include +#endif +#define JSON_ASSERT_UNREACHABLE assert(false) + +namespace Json { + +// This is a walkaround to avoid the static initialization of Value::null. +// kNull must be word-aligned to avoid crashing on ARM. We use an alignment of +// 8 (instead of 4) as a bit of future-proofing. +#if defined(__ARMEL__) +#define ALIGNAS(byte_alignment) __attribute__((aligned(byte_alignment))) +#else +// This exists for binary compatibility only. Use nullRef. +const Value Value::null; +#define ALIGNAS(byte_alignment) +#endif +static const unsigned char ALIGNAS(8) kNull[sizeof(Value)] = { 0 }; +const unsigned char& kNullRef = kNull[0]; +const Value& Value::nullRef = reinterpret_cast(kNullRef); + +const Int Value::minInt = Int(~(UInt(-1) / 2)); +const Int Value::maxInt = Int(UInt(-1) / 2); +const UInt Value::maxUInt = UInt(-1); +#if defined(JSON_HAS_INT64) +const Int64 Value::minInt64 = Int64(~(UInt64(-1) / 2)); +const Int64 Value::maxInt64 = Int64(UInt64(-1) / 2); +const UInt64 Value::maxUInt64 = UInt64(-1); +// The constant is hard-coded because some compiler have trouble +// converting Value::maxUInt64 to a double correctly (AIX/xlC). +// Assumes that UInt64 is a 64 bits integer. +static const double maxUInt64AsDouble = 18446744073709551615.0; +#endif // defined(JSON_HAS_INT64) +const LargestInt Value::minLargestInt = LargestInt(~(LargestUInt(-1) / 2)); +const LargestInt Value::maxLargestInt = LargestInt(LargestUInt(-1) / 2); +const LargestUInt Value::maxLargestUInt = LargestUInt(-1); + +#if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) +template +static inline bool InRange(double d, T min, U max) { + return d >= min && d <= max; +} +#else // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) +static inline double integerToDouble(Json::UInt64 value) { + return static_cast(Int64(value / 2)) * 2.0 + Int64(value & 1); +} + +template static inline double integerToDouble(T value) { + return static_cast(value); +} + +template +static inline bool InRange(double d, T min, U max) { + return d >= integerToDouble(min) && d <= integerToDouble(max); +} +#endif // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + +/** Duplicates the specified string value. + * @param value Pointer to the string to duplicate. Must be zero-terminated if + * length is "unknown". + * @param length Length of the value. if equals to unknown, then it will be + * computed using strlen(value). + * @return Pointer on the duplicate instance of string. + */ +static inline char* duplicateStringValue(const char* value, + size_t length) { + // Avoid an integer overflow in the call to malloc below by limiting length + // to a sane value. + if (length >= (size_t)Value::maxInt) + length = Value::maxInt - 1; + + char* newString = static_cast(malloc(length + 1)); + if (newString == NULL) { + throwRuntimeError( + "in Json::Value::duplicateStringValue(): " + "Failed to allocate string value buffer"); + } + memcpy(newString, value, length); + newString[length] = 0; + return newString; +} + +/* Record the length as a prefix. + */ +static inline char* duplicateAndPrefixStringValue( + const char* value, + unsigned int length) +{ + // Avoid an integer overflow in the call to malloc below by limiting length + // to a sane value. + JSON_ASSERT_MESSAGE(length <= (unsigned)Value::maxInt - sizeof(unsigned) - 1U, + "in Json::Value::duplicateAndPrefixStringValue(): " + "length too big for prefixing"); + unsigned actualLength = length + static_cast(sizeof(unsigned)) + 1U; + char* newString = static_cast(malloc(actualLength)); + if (newString == 0) { + throwRuntimeError( + "in Json::Value::duplicateAndPrefixStringValue(): " + "Failed to allocate string value buffer"); + } + *reinterpret_cast(newString) = length; + memcpy(newString + sizeof(unsigned), value, length); + newString[actualLength - 1U] = 0; // to avoid buffer over-run accidents by users later + return newString; +} +inline static void decodePrefixedString( + bool isPrefixed, char const* prefixed, + unsigned* length, char const** value) +{ + if (!isPrefixed) { + *length = static_cast(strlen(prefixed)); + *value = prefixed; + } else { + *length = *reinterpret_cast(prefixed); + *value = prefixed + sizeof(unsigned); + } +} +/** Free the string duplicated by duplicateStringValue()/duplicateAndPrefixStringValue(). + */ +static inline void releaseStringValue(char* value) { free(value); } + +} // namespace Json + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ValueInternals... +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +#if !defined(JSON_IS_AMALGAMATION) + +#include "json_valueiterator.inl" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +Exception::Exception(std::string const& msg) + : msg_(msg) +{} +Exception::~Exception() throw() +{} +char const* Exception::what() const throw() +{ + return msg_.c_str(); +} +RuntimeError::RuntimeError(std::string const& msg) + : Exception(msg) +{} +LogicError::LogicError(std::string const& msg) + : Exception(msg) +{} +JSONCPP_NORETURN void throwRuntimeError(std::string const& msg) +{ + throw RuntimeError(msg); +} +JSONCPP_NORETURN void throwLogicError(std::string const& msg) +{ + throw LogicError(msg); +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class Value::CommentInfo +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +Value::CommentInfo::CommentInfo() : comment_(0) {} + +Value::CommentInfo::~CommentInfo() { + if (comment_) + releaseStringValue(comment_); +} + +void Value::CommentInfo::setComment(const char* text, size_t len) { + if (comment_) { + releaseStringValue(comment_); + comment_ = 0; + } + JSON_ASSERT(text != 0); + JSON_ASSERT_MESSAGE( + text[0] == '\0' || text[0] == '/', + "in Json::Value::setComment(): Comments must start with /"); + // It seems that /**/ style comments are acceptable as well. + comment_ = duplicateStringValue(text, len); +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class Value::CZString +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +// Notes: policy_ indicates if the string was allocated when +// a string is stored. + +Value::CZString::CZString(ArrayIndex aindex) : cstr_(0), index_(aindex) {} + +Value::CZString::CZString(char const* str, unsigned ulength, DuplicationPolicy allocate) + : cstr_(str) +{ + // allocate != duplicate + storage_.policy_ = allocate & 0x3; + storage_.length_ = ulength & 0x3FFFFFFF; +} + +Value::CZString::CZString(const CZString& other) + : cstr_(other.storage_.policy_ != noDuplication && other.cstr_ != 0 + ? duplicateStringValue(other.cstr_, other.storage_.length_) + : other.cstr_) +{ + storage_.policy_ = (other.cstr_ + ? (static_cast(other.storage_.policy_) == noDuplication + ? noDuplication : duplicate) + : static_cast(other.storage_.policy_)); + storage_.length_ = other.storage_.length_; +} + +Value::CZString::~CZString() { + if (cstr_ && storage_.policy_ == duplicate) + releaseStringValue(const_cast(cstr_)); +} + +void Value::CZString::swap(CZString& other) { + std::swap(cstr_, other.cstr_); + std::swap(index_, other.index_); +} + +Value::CZString& Value::CZString::operator=(CZString other) { + swap(other); + return *this; +} + +bool Value::CZString::operator<(const CZString& other) const { + if (!cstr_) return index_ < other.index_; + //return strcmp(cstr_, other.cstr_) < 0; + // Assume both are strings. + unsigned this_len = this->storage_.length_; + unsigned other_len = other.storage_.length_; + unsigned min_len = std::min(this_len, other_len); + int comp = memcmp(this->cstr_, other.cstr_, min_len); + if (comp < 0) return true; + if (comp > 0) return false; + return (this_len < other_len); +} + +bool Value::CZString::operator==(const CZString& other) const { + if (!cstr_) return index_ == other.index_; + //return strcmp(cstr_, other.cstr_) == 0; + // Assume both are strings. + unsigned this_len = this->storage_.length_; + unsigned other_len = other.storage_.length_; + if (this_len != other_len) return false; + int comp = memcmp(this->cstr_, other.cstr_, this_len); + return comp == 0; +} + +ArrayIndex Value::CZString::index() const { return index_; } + +//const char* Value::CZString::c_str() const { return cstr_; } +const char* Value::CZString::data() const { return cstr_; } +unsigned Value::CZString::length() const { return storage_.length_; } +bool Value::CZString::isStaticString() const { return storage_.policy_ == noDuplication; } + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class Value::Value +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +/*! \internal Default constructor initialization must be equivalent to: + * memset( this, 0, sizeof(Value) ) + * This optimization is used in ValueInternalMap fast allocator. + */ +Value::Value(ValueType vtype) { + initBasic(vtype); + switch (vtype) { + case nullValue: + break; + case intValue: + case uintValue: + value_.int_ = 0; + break; + case realValue: + value_.real_ = 0.0; + break; + case stringValue: + value_.string_ = 0; + break; + case arrayValue: + case objectValue: + value_.map_ = new ObjectValues(); + break; + case booleanValue: + value_.bool_ = false; + break; + default: + JSON_ASSERT_UNREACHABLE; + } +} + +Value::Value(Int value) { + initBasic(intValue); + value_.int_ = value; +} + +Value::Value(UInt value) { + initBasic(uintValue); + value_.uint_ = value; +} +#if defined(JSON_HAS_INT64) +Value::Value(Int64 value) { + initBasic(intValue); + value_.int_ = value; +} +Value::Value(UInt64 value) { + initBasic(uintValue); + value_.uint_ = value; +} +#endif // defined(JSON_HAS_INT64) + +Value::Value(double value) { + initBasic(realValue); + value_.real_ = value; +} + +Value::Value(const char* value) { + initBasic(stringValue, true); + value_.string_ = duplicateAndPrefixStringValue(value, static_cast(strlen(value))); +} + +Value::Value(const char* beginValue, const char* endValue) { + initBasic(stringValue, true); + value_.string_ = + duplicateAndPrefixStringValue(beginValue, static_cast(endValue - beginValue)); +} + +Value::Value(const std::string& value) { + initBasic(stringValue, true); + value_.string_ = + duplicateAndPrefixStringValue(value.data(), static_cast(value.length())); +} + +Value::Value(const StaticString& value) { + initBasic(stringValue); + value_.string_ = const_cast(value.c_str()); +} + +#ifdef JSON_USE_CPPTL +Value::Value(const CppTL::ConstString& value) { + initBasic(stringValue, true); + value_.string_ = duplicateAndPrefixStringValue(value, static_cast(value.length())); +} +#endif + +Value::Value(bool value) { + initBasic(booleanValue); + value_.bool_ = value; +} + +Value::Value(Value const& other) + : type_(other.type_), allocated_(false) + , + comments_(0) +{ + switch (type_) { + case nullValue: + case intValue: + case uintValue: + case realValue: + case booleanValue: + value_ = other.value_; + break; + case stringValue: + if (other.value_.string_ && other.allocated_) { + unsigned len; + char const* str; + decodePrefixedString(other.allocated_, other.value_.string_, + &len, &str); + value_.string_ = duplicateAndPrefixStringValue(str, len); + allocated_ = true; + } else { + value_.string_ = other.value_.string_; + allocated_ = false; + } + break; + case arrayValue: + case objectValue: + value_.map_ = new ObjectValues(*other.value_.map_); + break; + default: + JSON_ASSERT_UNREACHABLE; + } + if (other.comments_) { + comments_ = new CommentInfo[numberOfCommentPlacement]; + for (int comment = 0; comment < numberOfCommentPlacement; ++comment) { + const CommentInfo& otherComment = other.comments_[comment]; + if (otherComment.comment_) + comments_[comment].setComment( + otherComment.comment_, strlen(otherComment.comment_)); + } + } +} + +Value::~Value() { + switch (type_) { + case nullValue: + case intValue: + case uintValue: + case realValue: + case booleanValue: + break; + case stringValue: + if (allocated_) + releaseStringValue(value_.string_); + break; + case arrayValue: + case objectValue: + delete value_.map_; + break; + default: + JSON_ASSERT_UNREACHABLE; + } + + if (comments_) + delete[] comments_; +} + +Value &Value::operator=(const Value &other) { + Value temp(other); + swap(temp); + return *this; +} + +void Value::swapPayload(Value& other) { + ValueType temp = type_; + type_ = other.type_; + other.type_ = temp; + std::swap(value_, other.value_); + int temp2 = allocated_; + allocated_ = other.allocated_; + other.allocated_ = temp2 & 0x1; +} + +void Value::swap(Value& other) { + swapPayload(other); + std::swap(comments_, other.comments_); +} + +ValueType Value::type() const { return type_; } + +int Value::compare(const Value& other) const { + if (*this < other) + return -1; + if (*this > other) + return 1; + return 0; +} + +bool Value::operator<(const Value& other) const { + int typeDelta = type_ - other.type_; + if (typeDelta) + return typeDelta < 0 ? true : false; + switch (type_) { + case nullValue: + return false; + case intValue: + return value_.int_ < other.value_.int_; + case uintValue: + return value_.uint_ < other.value_.uint_; + case realValue: + return value_.real_ < other.value_.real_; + case booleanValue: + return value_.bool_ < other.value_.bool_; + case stringValue: + { + if ((value_.string_ == 0) || (other.value_.string_ == 0)) { + if (other.value_.string_) return true; + else return false; + } + unsigned this_len; + unsigned other_len; + char const* this_str; + char const* other_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + decodePrefixedString(other.allocated_, other.value_.string_, &other_len, &other_str); + unsigned min_len = std::min(this_len, other_len); + int comp = memcmp(this_str, other_str, min_len); + if (comp < 0) return true; + if (comp > 0) return false; + return (this_len < other_len); + } + case arrayValue: + case objectValue: { + int delta = int(value_.map_->size() - other.value_.map_->size()); + if (delta) + return delta < 0; + return (*value_.map_) < (*other.value_.map_); + } + default: + JSON_ASSERT_UNREACHABLE; + } + return false; // unreachable +} + +bool Value::operator<=(const Value& other) const { return !(other < *this); } + +bool Value::operator>=(const Value& other) const { return !(*this < other); } + +bool Value::operator>(const Value& other) const { return other < *this; } + +bool Value::operator==(const Value& other) const { + // if ( type_ != other.type_ ) + // GCC 2.95.3 says: + // attempt to take address of bit-field structure member `Json::Value::type_' + // Beats me, but a temp solves the problem. + int temp = other.type_; + if (type_ != temp) + return false; + switch (type_) { + case nullValue: + return true; + case intValue: + return value_.int_ == other.value_.int_; + case uintValue: + return value_.uint_ == other.value_.uint_; + case realValue: + return value_.real_ == other.value_.real_; + case booleanValue: + return value_.bool_ == other.value_.bool_; + case stringValue: + { + if ((value_.string_ == 0) || (other.value_.string_ == 0)) { + return (value_.string_ == other.value_.string_); + } + unsigned this_len; + unsigned other_len; + char const* this_str; + char const* other_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + decodePrefixedString(other.allocated_, other.value_.string_, &other_len, &other_str); + if (this_len != other_len) return false; + int comp = memcmp(this_str, other_str, this_len); + return comp == 0; + } + case arrayValue: + case objectValue: + return value_.map_->size() == other.value_.map_->size() && + (*value_.map_) == (*other.value_.map_); + default: + JSON_ASSERT_UNREACHABLE; + } + return false; // unreachable +} + +bool Value::operator!=(const Value& other) const { return !(*this == other); } + +const char* Value::asCString() const { + JSON_ASSERT_MESSAGE(type_ == stringValue, + "in Json::Value::asCString(): requires stringValue"); + if (value_.string_ == 0) return 0; + unsigned this_len; + char const* this_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + return this_str; +} + +bool Value::getString(char const** str, char const** cend) const { + if (type_ != stringValue) return false; + if (value_.string_ == 0) return false; + unsigned length; + decodePrefixedString(this->allocated_, this->value_.string_, &length, str); + *cend = *str + length; + return true; +} + +std::string Value::asString() const { + switch (type_) { + case nullValue: + return ""; + case stringValue: + { + if (value_.string_ == 0) return ""; + unsigned this_len; + char const* this_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + return std::string(this_str, this_len); + } + case booleanValue: + return value_.bool_ ? "true" : "false"; + case intValue: + return valueToString(value_.int_); + case uintValue: + return valueToString(value_.uint_); + case realValue: + return valueToString(value_.real_); + default: + JSON_FAIL_MESSAGE("Type is not convertible to string"); + } +} + +#ifdef JSON_USE_CPPTL +CppTL::ConstString Value::asConstString() const { + unsigned len; + char const* str; + decodePrefixedString(allocated_, value_.string_, + &len, &str); + return CppTL::ConstString(str, len); +} +#endif + +Value::Int Value::asInt() const { + switch (type_) { + case intValue: + JSON_ASSERT_MESSAGE(isInt(), "LargestInt out of Int range"); + return Int(value_.int_); + case uintValue: + JSON_ASSERT_MESSAGE(isInt(), "LargestUInt out of Int range"); + return Int(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, minInt, maxInt), + "double out of Int range"); + return Int(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to Int."); +} + +Value::UInt Value::asUInt() const { + switch (type_) { + case intValue: + JSON_ASSERT_MESSAGE(isUInt(), "LargestInt out of UInt range"); + return UInt(value_.int_); + case uintValue: + JSON_ASSERT_MESSAGE(isUInt(), "LargestUInt out of UInt range"); + return UInt(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, 0, maxUInt), + "double out of UInt range"); + return UInt(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to UInt."); +} + +#if defined(JSON_HAS_INT64) + +Value::Int64 Value::asInt64() const { + switch (type_) { + case intValue: + return Int64(value_.int_); + case uintValue: + JSON_ASSERT_MESSAGE(isInt64(), "LargestUInt out of Int64 range"); + return Int64(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, minInt64, maxInt64), + "double out of Int64 range"); + return Int64(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to Int64."); +} + +Value::UInt64 Value::asUInt64() const { + switch (type_) { + case intValue: + JSON_ASSERT_MESSAGE(isUInt64(), "LargestInt out of UInt64 range"); + return UInt64(value_.int_); + case uintValue: + return UInt64(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, 0, maxUInt64), + "double out of UInt64 range"); + return UInt64(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to UInt64."); +} +#endif // if defined(JSON_HAS_INT64) + +LargestInt Value::asLargestInt() const { +#if defined(JSON_NO_INT64) + return asInt(); +#else + return asInt64(); +#endif +} + +LargestUInt Value::asLargestUInt() const { +#if defined(JSON_NO_INT64) + return asUInt(); +#else + return asUInt64(); +#endif +} + +double Value::asDouble() const { + switch (type_) { + case intValue: + return static_cast(value_.int_); + case uintValue: +#if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return static_cast(value_.uint_); +#else // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return integerToDouble(value_.uint_); +#endif // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + case realValue: + return value_.real_; + case nullValue: + return 0.0; + case booleanValue: + return value_.bool_ ? 1.0 : 0.0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to double."); +} + +float Value::asFloat() const { + switch (type_) { + case intValue: + return static_cast(value_.int_); + case uintValue: +#if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return static_cast(value_.uint_); +#else // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return integerToDouble(value_.uint_); +#endif // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + case realValue: + return static_cast(value_.real_); + case nullValue: + return 0.0; + case booleanValue: + return value_.bool_ ? 1.0f : 0.0f; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to float."); +} + +bool Value::asBool() const { + switch (type_) { + case booleanValue: + return value_.bool_; + case nullValue: + return false; + case intValue: + return value_.int_ ? true : false; + case uintValue: + return value_.uint_ ? true : false; + case realValue: + // This is kind of strange. Not recommended. + return (value_.real_ != 0.0) ? true : false; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to bool."); +} + +bool Value::isConvertibleTo(ValueType other) const { + switch (other) { + case nullValue: + return (isNumeric() && asDouble() == 0.0) || + (type_ == booleanValue && value_.bool_ == false) || + (type_ == stringValue && asString() == "") || + (type_ == arrayValue && value_.map_->size() == 0) || + (type_ == objectValue && value_.map_->size() == 0) || + type_ == nullValue; + case intValue: + return isInt() || + (type_ == realValue && InRange(value_.real_, minInt, maxInt)) || + type_ == booleanValue || type_ == nullValue; + case uintValue: + return isUInt() || + (type_ == realValue && InRange(value_.real_, 0, maxUInt)) || + type_ == booleanValue || type_ == nullValue; + case realValue: + return isNumeric() || type_ == booleanValue || type_ == nullValue; + case booleanValue: + return isNumeric() || type_ == booleanValue || type_ == nullValue; + case stringValue: + return isNumeric() || type_ == booleanValue || type_ == stringValue || + type_ == nullValue; + case arrayValue: + return type_ == arrayValue || type_ == nullValue; + case objectValue: + return type_ == objectValue || type_ == nullValue; + } + JSON_ASSERT_UNREACHABLE; + return false; +} + +/// Number of values in array or object +ArrayIndex Value::size() const { + switch (type_) { + case nullValue: + case intValue: + case uintValue: + case realValue: + case booleanValue: + case stringValue: + return 0; + case arrayValue: // size of the array is highest index + 1 + if (!value_.map_->empty()) { + ObjectValues::const_iterator itLast = value_.map_->end(); + --itLast; + return (*itLast).first.index() + 1; + } + return 0; + case objectValue: + return ArrayIndex(value_.map_->size()); + } + JSON_ASSERT_UNREACHABLE; + return 0; // unreachable; +} + +bool Value::empty() const { + if (isNull() || isArray() || isObject()) + return size() == 0u; + else + return false; +} + +bool Value::operator!() const { return isNull(); } + +void Value::clear() { + JSON_ASSERT_MESSAGE(type_ == nullValue || type_ == arrayValue || + type_ == objectValue, + "in Json::Value::clear(): requires complex value"); + switch (type_) { + case arrayValue: + case objectValue: + value_.map_->clear(); + break; + default: + break; + } +} + +void Value::resize(ArrayIndex newSize) { + JSON_ASSERT_MESSAGE(type_ == nullValue || type_ == arrayValue, + "in Json::Value::resize(): requires arrayValue"); + if (type_ == nullValue) + *this = Value(arrayValue); + ArrayIndex oldSize = size(); + if (newSize == 0) + clear(); + else if (newSize > oldSize) + (*this)[newSize - 1]; + else { + for (ArrayIndex index = newSize; index < oldSize; ++index) { + value_.map_->erase(index); + } + assert(size() == newSize); + } +} + +Value& Value::operator[](ArrayIndex index) { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == arrayValue, + "in Json::Value::operator[](ArrayIndex): requires arrayValue"); + if (type_ == nullValue) + *this = Value(arrayValue); + CZString key(index); + ObjectValues::iterator it = value_.map_->lower_bound(key); + if (it != value_.map_->end() && (*it).first == key) + return (*it).second; + + ObjectValues::value_type defaultValue(key, nullRef); + it = value_.map_->insert(it, defaultValue); + return (*it).second; +} + +Value& Value::operator[](int index) { + JSON_ASSERT_MESSAGE( + index >= 0, + "in Json::Value::operator[](int index): index cannot be negative"); + return (*this)[ArrayIndex(index)]; +} + +const Value& Value::operator[](ArrayIndex index) const { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == arrayValue, + "in Json::Value::operator[](ArrayIndex)const: requires arrayValue"); + if (type_ == nullValue) + return nullRef; + CZString key(index); + ObjectValues::const_iterator it = value_.map_->find(key); + if (it == value_.map_->end()) + return nullRef; + return (*it).second; +} + +const Value& Value::operator[](int index) const { + JSON_ASSERT_MESSAGE( + index >= 0, + "in Json::Value::operator[](int index) const: index cannot be negative"); + return (*this)[ArrayIndex(index)]; +} + +void Value::initBasic(ValueType vtype, bool allocated) { + type_ = vtype; + allocated_ = allocated; + comments_ = 0; +} + +// Access an object value by name, create a null member if it does not exist. +// @pre Type of '*this' is object or null. +// @param key is null-terminated. +Value& Value::resolveReference(const char* key) { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::resolveReference(): requires objectValue"); + if (type_ == nullValue) + *this = Value(objectValue); + CZString actualKey( + key, static_cast(strlen(key)), CZString::noDuplication); // NOTE! + ObjectValues::iterator it = value_.map_->lower_bound(actualKey); + if (it != value_.map_->end() && (*it).first == actualKey) + return (*it).second; + + ObjectValues::value_type defaultValue(actualKey, nullRef); + it = value_.map_->insert(it, defaultValue); + Value& value = (*it).second; + return value; +} + +// @param key is not null-terminated. +Value& Value::resolveReference(char const* key, char const* cend) +{ + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::resolveReference(key, end): requires objectValue"); + if (type_ == nullValue) + *this = Value(objectValue); + CZString actualKey( + key, static_cast(cend-key), CZString::duplicateOnCopy); + ObjectValues::iterator it = value_.map_->lower_bound(actualKey); + if (it != value_.map_->end() && (*it).first == actualKey) + return (*it).second; + + ObjectValues::value_type defaultValue(actualKey, nullRef); + it = value_.map_->insert(it, defaultValue); + Value& value = (*it).second; + return value; +} + +Value Value::get(ArrayIndex index, const Value& defaultValue) const { + const Value* value = &((*this)[index]); + return value == &nullRef ? defaultValue : *value; +} + +bool Value::isValidIndex(ArrayIndex index) const { return index < size(); } + +Value const* Value::find(char const* key, char const* cend) const +{ + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::find(key, end, found): requires objectValue or nullValue"); + if (type_ == nullValue) return NULL; + CZString actualKey(key, static_cast(cend-key), CZString::noDuplication); + ObjectValues::const_iterator it = value_.map_->find(actualKey); + if (it == value_.map_->end()) return NULL; + return &(*it).second; +} +const Value& Value::operator[](const char* key) const +{ + Value const* found = find(key, key + strlen(key)); + if (!found) return nullRef; + return *found; +} +Value const& Value::operator[](std::string const& key) const +{ + Value const* found = find(key.data(), key.data() + key.length()); + if (!found) return nullRef; + return *found; +} + +Value& Value::operator[](const char* key) { + return resolveReference(key, key + strlen(key)); +} + +Value& Value::operator[](const std::string& key) { + return resolveReference(key.data(), key.data() + key.length()); +} + +Value& Value::operator[](const StaticString& key) { + return resolveReference(key.c_str()); +} + +#ifdef JSON_USE_CPPTL +Value& Value::operator[](const CppTL::ConstString& key) { + return resolveReference(key.c_str(), key.end_c_str()); +} +Value const& Value::operator[](CppTL::ConstString const& key) const +{ + Value const* found = find(key.c_str(), key.end_c_str()); + if (!found) return nullRef; + return *found; +} +#endif + +Value& Value::append(const Value& value) { return (*this)[size()] = value; } + +Value Value::get(char const* key, char const* cend, Value const& defaultValue) const +{ + Value const* found = find(key, cend); + return !found ? defaultValue : *found; +} +Value Value::get(char const* key, Value const& defaultValue) const +{ + return get(key, key + strlen(key), defaultValue); +} +Value Value::get(std::string const& key, Value const& defaultValue) const +{ + return get(key.data(), key.data() + key.length(), defaultValue); +} + + +bool Value::removeMember(const char* key, const char* cend, Value* removed) +{ + if (type_ != objectValue) { + return false; + } + CZString actualKey(key, static_cast(cend-key), CZString::noDuplication); + ObjectValues::iterator it = value_.map_->find(actualKey); + if (it == value_.map_->end()) + return false; + *removed = it->second; + value_.map_->erase(it); + return true; +} +bool Value::removeMember(const char* key, Value* removed) +{ + return removeMember(key, key + strlen(key), removed); +} +bool Value::removeMember(std::string const& key, Value* removed) +{ + return removeMember(key.data(), key.data() + key.length(), removed); +} +Value Value::removeMember(const char* key) +{ + JSON_ASSERT_MESSAGE(type_ == nullValue || type_ == objectValue, + "in Json::Value::removeMember(): requires objectValue"); + if (type_ == nullValue) + return nullRef; + + Value removed; // null + removeMember(key, key + strlen(key), &removed); + return removed; // still null if removeMember() did nothing +} +Value Value::removeMember(const std::string& key) +{ + return removeMember(key.c_str()); +} + +bool Value::removeIndex(ArrayIndex index, Value* removed) { + if (type_ != arrayValue) { + return false; + } + CZString key(index); + ObjectValues::iterator it = value_.map_->find(key); + if (it == value_.map_->end()) { + return false; + } + *removed = it->second; + ArrayIndex oldSize = size(); + // shift left all items left, into the place of the "removed" + for (ArrayIndex i = index; i < (oldSize - 1); ++i){ + CZString keey(i); + (*value_.map_)[keey] = (*this)[i + 1]; + } + // erase the last one ("leftover") + CZString keyLast(oldSize - 1); + ObjectValues::iterator itLast = value_.map_->find(keyLast); + value_.map_->erase(itLast); + return true; +} + +#ifdef JSON_USE_CPPTL +Value Value::get(const CppTL::ConstString& key, + const Value& defaultValue) const { + return get(key.c_str(), key.end_c_str(), defaultValue); +} +#endif + +bool Value::isMember(char const* key, char const* cend) const +{ + Value const* value = find(key, cend); + return NULL != value; +} +bool Value::isMember(char const* key) const +{ + return isMember(key, key + strlen(key)); +} +bool Value::isMember(std::string const& key) const +{ + return isMember(key.data(), key.data() + key.length()); +} + +#ifdef JSON_USE_CPPTL +bool Value::isMember(const CppTL::ConstString& key) const { + return isMember(key.c_str(), key.end_c_str()); +} +#endif + +Value::Members Value::getMemberNames() const { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::getMemberNames(), value must be objectValue"); + if (type_ == nullValue) + return Value::Members(); + Members members; + members.reserve(value_.map_->size()); + ObjectValues::const_iterator it = value_.map_->begin(); + ObjectValues::const_iterator itEnd = value_.map_->end(); + for (; it != itEnd; ++it) { + members.push_back(std::string((*it).first.data(), + (*it).first.length())); + } + return members; +} +// +//# ifdef JSON_USE_CPPTL +// EnumMemberNames +// Value::enumMemberNames() const +//{ +// if ( type_ == objectValue ) +// { +// return CppTL::Enum::any( CppTL::Enum::transform( +// CppTL::Enum::keys( *(value_.map_), CppTL::Type() ), +// MemberNamesTransform() ) ); +// } +// return EnumMemberNames(); +//} +// +// +// EnumValues +// Value::enumValues() const +//{ +// if ( type_ == objectValue || type_ == arrayValue ) +// return CppTL::Enum::anyValues( *(value_.map_), +// CppTL::Type() ); +// return EnumValues(); +//} +// +//# endif + +static bool IsIntegral(double d) { + double integral_part; + return modf(d, &integral_part) == 0.0; +} + +bool Value::isNull() const { return type_ == nullValue; } + +bool Value::isBool() const { return type_ == booleanValue; } + +bool Value::isInt() const { + switch (type_) { + case intValue: + return value_.int_ >= minInt && value_.int_ <= maxInt; + case uintValue: + return value_.uint_ <= UInt(maxInt); + case realValue: + return value_.real_ >= minInt && value_.real_ <= maxInt && + IsIntegral(value_.real_); + default: + break; + } + return false; +} + +bool Value::isUInt() const { + switch (type_) { + case intValue: + return value_.int_ >= 0 && LargestUInt(value_.int_) <= LargestUInt(maxUInt); + case uintValue: + return value_.uint_ <= maxUInt; + case realValue: + return value_.real_ >= 0 && value_.real_ <= maxUInt && + IsIntegral(value_.real_); + default: + break; + } + return false; +} + +bool Value::isInt64() const { +#if defined(JSON_HAS_INT64) + switch (type_) { + case intValue: + return true; + case uintValue: + return value_.uint_ <= UInt64(maxInt64); + case realValue: + // Note that maxInt64 (= 2^63 - 1) is not exactly representable as a + // double, so double(maxInt64) will be rounded up to 2^63. Therefore we + // require the value to be strictly less than the limit. + return value_.real_ >= double(minInt64) && + value_.real_ < double(maxInt64) && IsIntegral(value_.real_); + default: + break; + } +#endif // JSON_HAS_INT64 + return false; +} + +bool Value::isUInt64() const { +#if defined(JSON_HAS_INT64) + switch (type_) { + case intValue: + return value_.int_ >= 0; + case uintValue: + return true; + case realValue: + // Note that maxUInt64 (= 2^64 - 1) is not exactly representable as a + // double, so double(maxUInt64) will be rounded up to 2^64. Therefore we + // require the value to be strictly less than the limit. + return value_.real_ >= 0 && value_.real_ < maxUInt64AsDouble && + IsIntegral(value_.real_); + default: + break; + } +#endif // JSON_HAS_INT64 + return false; +} + +bool Value::isIntegral() const { +#if defined(JSON_HAS_INT64) + return isInt64() || isUInt64(); +#else + return isInt() || isUInt(); +#endif +} + +bool Value::isDouble() const { return type_ == realValue || isIntegral(); } + +bool Value::isNumeric() const { return isIntegral() || isDouble(); } + +bool Value::isString() const { return type_ == stringValue; } + +bool Value::isArray() const { return type_ == arrayValue; } + +bool Value::isObject() const { return type_ == objectValue; } + +void Value::setComment(const char* comment, size_t len, CommentPlacement placement) { + if (!comments_) + comments_ = new CommentInfo[numberOfCommentPlacement]; + if ((len > 0) && (comment[len-1] == '\n')) { + // Always discard trailing newline, to aid indentation. + len -= 1; + } + comments_[placement].setComment(comment, len); +} + +void Value::setComment(const char* comment, CommentPlacement placement) { + setComment(comment, strlen(comment), placement); +} + +void Value::setComment(const std::string& comment, CommentPlacement placement) { + setComment(comment.c_str(), comment.length(), placement); +} + +bool Value::hasComment(CommentPlacement placement) const { + return comments_ != 0 && comments_[placement].comment_ != 0; +} + +std::string Value::getComment(CommentPlacement placement) const { + if (hasComment(placement)) + return comments_[placement].comment_; + return ""; +} + +std::string Value::toStyledString() const { + StyledWriter writer; + return writer.write(*this); +} + +Value::const_iterator Value::begin() const { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return const_iterator(value_.map_->begin()); + break; + default: + break; + } + return const_iterator(); +} + +Value::const_iterator Value::end() const { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return const_iterator(value_.map_->end()); + break; + default: + break; + } + return const_iterator(); +} + +Value::iterator Value::begin() { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return iterator(value_.map_->begin()); + break; + default: + break; + } + return iterator(); +} + +Value::iterator Value::end() { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return iterator(value_.map_->end()); + break; + default: + break; + } + return iterator(); +} + +// class PathArgument +// ////////////////////////////////////////////////////////////////// + +PathArgument::PathArgument() : key_(), index_(), kind_(kindNone) {} + +PathArgument::PathArgument(ArrayIndex index) + : key_(), index_(index), kind_(kindIndex) {} + +PathArgument::PathArgument(const char* key) + : key_(key), index_(), kind_(kindKey) {} + +PathArgument::PathArgument(const std::string& key) + : key_(key.c_str()), index_(), kind_(kindKey) {} + +// class Path +// ////////////////////////////////////////////////////////////////// + +Path::Path(const std::string& path, + const PathArgument& a1, + const PathArgument& a2, + const PathArgument& a3, + const PathArgument& a4, + const PathArgument& a5) { + InArgs in; + in.push_back(&a1); + in.push_back(&a2); + in.push_back(&a3); + in.push_back(&a4); + in.push_back(&a5); + makePath(path, in); +} + +void Path::makePath(const std::string& path, const InArgs& in) { + const char* current = path.c_str(); + const char* end = current + path.length(); + InArgs::const_iterator itInArg = in.begin(); + while (current != end) { + if (*current == '[') { + ++current; + if (*current == '%') + addPathInArg(path, in, itInArg, PathArgument::kindIndex); + else { + ArrayIndex index = 0; + for (; current != end && *current >= '0' && *current <= '9'; ++current) + index = index * 10 + ArrayIndex(*current - '0'); + args_.push_back(index); + } + if (current == end || *current++ != ']') + invalidPath(path, int(current - path.c_str())); + } else if (*current == '%') { + addPathInArg(path, in, itInArg, PathArgument::kindKey); + ++current; + } else if (*current == '.') { + ++current; + } else { + const char* beginName = current; + while (current != end && !strchr("[.", *current)) + ++current; + args_.push_back(std::string(beginName, current)); + } + } +} + +void Path::addPathInArg(const std::string& /*path*/, + const InArgs& in, + InArgs::const_iterator& itInArg, + PathArgument::Kind kind) { + if (itInArg == in.end()) { + // Error: missing argument %d + } else if ((*itInArg)->kind_ != kind) { + // Error: bad argument type + } else { + args_.push_back(**itInArg); + } +} + +void Path::invalidPath(const std::string& /*path*/, int /*location*/) { + // Error: invalid path. +} + +const Value& Path::resolve(const Value& root) const { + const Value* node = &root; + for (Args::const_iterator it = args_.begin(); it != args_.end(); ++it) { + const PathArgument& arg = *it; + if (arg.kind_ == PathArgument::kindIndex) { + if (!node->isArray() || !node->isValidIndex(arg.index_)) { + // Error: unable to resolve path (array value expected at position... + } + node = &((*node)[arg.index_]); + } else if (arg.kind_ == PathArgument::kindKey) { + if (!node->isObject()) { + // Error: unable to resolve path (object value expected at position...) + } + node = &((*node)[arg.key_]); + if (node == &Value::nullRef) { + // Error: unable to resolve path (object has no member named '' at + // position...) + } + } + } + return *node; +} + +Value Path::resolve(const Value& root, const Value& defaultValue) const { + const Value* node = &root; + for (Args::const_iterator it = args_.begin(); it != args_.end(); ++it) { + const PathArgument& arg = *it; + if (arg.kind_ == PathArgument::kindIndex) { + if (!node->isArray() || !node->isValidIndex(arg.index_)) + return defaultValue; + node = &((*node)[arg.index_]); + } else if (arg.kind_ == PathArgument::kindKey) { + if (!node->isObject()) + return defaultValue; + node = &((*node)[arg.key_]); + if (node == &Value::nullRef) + return defaultValue; + } + } + return *node; +} + +Value& Path::make(Value& root) const { + Value* node = &root; + for (Args::const_iterator it = args_.begin(); it != args_.end(); ++it) { + const PathArgument& arg = *it; + if (arg.kind_ == PathArgument::kindIndex) { + if (!node->isArray()) { + // Error: node is not an array at position ... + } + node = &((*node)[arg.index_]); + } else if (arg.kind_ == PathArgument::kindKey) { + if (!node->isObject()) { + // Error: node is not an object at position... + } + node = &((*node)[arg.key_]); + } + } + return *node; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_value.cpp +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_writer.cpp +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2011 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#if !defined(JSON_IS_AMALGAMATION) +#include +#include "json_tool.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__BORLANDC__) +#include +#endif +#if defined(_MSC_VER) && _MSC_VER >= 1200 && _MSC_VER < 1800 // Between VC++ 6.0 and VC++ 11.0 +#include +#define isfinite _finite +#elif defined(__sun) && defined(__SVR4) //Solaris +#include +#define isfinite finite +#else +#include +#define isfinite std::isfinite +#endif + +#if defined(_MSC_VER) +#if !defined(WINCE) && defined(__STDC_SECURE_LIB__) && _MSC_VER >= 1500 // VC++ 9.0 and above +#define snprintf sprintf_s +#elif _MSC_VER >= 1900 // VC++ 14.0 and above +#define snprintf std::snprintf +#else +#define snprintf _snprintf +#endif +#elif defined(__ANDROID__) +#define snprintf snprintf +#elif __cplusplus >= 201103L +#define snprintf std::snprintf +#endif + +#if defined(__BORLANDC__) +#include +#define isfinite _finite +#define snprintf _snprintf +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 // VC++ 8.0 +// Disable warning about strdup being deprecated. +#pragma warning(disable : 4996) +#endif + +namespace Json { + +#if JSON_HAS_UNIQUE_PTR +typedef std::unique_ptr const StreamWriterPtr; +#else +typedef std::auto_ptr StreamWriterPtr; +#endif + +static bool containsControlCharacter(const char* str) { + while (*str) { + if (isControlCharacter(*(str++))) + return true; + } + return false; +} + +static bool containsControlCharacter0(const char* str, unsigned len) { + char const* end = str + len; + while (end != str) { + if (isControlCharacter(*str) || 0==*str) + return true; + ++str; + } + return false; +} + +std::string valueToString(LargestInt value) { + UIntToStringBuffer buffer; + char* current = buffer + sizeof(buffer); + if (value == Value::minLargestInt) { + uintToString(LargestUInt(Value::maxLargestInt) + 1, current); + *--current = '-'; + } else if (value < 0) { + uintToString(LargestUInt(-value), current); + *--current = '-'; + } else { + uintToString(LargestUInt(value), current); + } + assert(current >= buffer); + return current; +} + +std::string valueToString(LargestUInt value) { + UIntToStringBuffer buffer; + char* current = buffer + sizeof(buffer); + uintToString(value, current); + assert(current >= buffer); + return current; +} + +#if defined(JSON_HAS_INT64) + +std::string valueToString(Int value) { + return valueToString(LargestInt(value)); +} + +std::string valueToString(UInt value) { + return valueToString(LargestUInt(value)); +} + +#endif // # if defined(JSON_HAS_INT64) + +std::string valueToString(double value, bool useSpecialFloats, unsigned int precision) { + // Allocate a buffer that is more than large enough to store the 16 digits of + // precision requested below. + char buffer[32]; + int len = -1; + + char formatString[6]; + sprintf(formatString, "%%.%dg", precision); + + // Print into the buffer. We need not request the alternative representation + // that always has a decimal point because JSON doesn't distingish the + // concepts of reals and integers. + if (isfinite(value)) { + len = snprintf(buffer, sizeof(buffer), formatString, value); + } else { + // IEEE standard states that NaN values will not compare to themselves + if (value != value) { + len = snprintf(buffer, sizeof(buffer), useSpecialFloats ? "NaN" : "null"); + } else if (value < 0) { + len = snprintf(buffer, sizeof(buffer), useSpecialFloats ? "-Infinity" : "-1e+9999"); + } else { + len = snprintf(buffer, sizeof(buffer), useSpecialFloats ? "Infinity" : "1e+9999"); + } + // For those, we do not need to call fixNumLoc, but it is fast. + } + assert(len >= 0); + fixNumericLocale(buffer, buffer + len); + return buffer; +} + +std::string valueToString(double value) { return valueToString(value, false, 17); } + +std::string valueToString(bool value) { return value ? "true" : "false"; } + +std::string valueToQuotedString(const char* value) { + if (value == NULL) + return ""; + // Not sure how to handle unicode... + if (strpbrk(value, "\"\\\b\f\n\r\t") == NULL && + !containsControlCharacter(value)) + return std::string("\"") + value + "\""; + // We have to walk value and escape any special characters. + // Appending to std::string is not efficient, but this should be rare. + // (Note: forward slashes are *not* rare, but I am not escaping them.) + std::string::size_type maxsize = + strlen(value) * 2 + 3; // allescaped+quotes+NULL + std::string result; + result.reserve(maxsize); // to avoid lots of mallocs + result += "\""; + for (const char* c = value; *c != 0; ++c) { + switch (*c) { + case '\"': + result += "\\\""; + break; + case '\\': + result += "\\\\"; + break; + case '\b': + result += "\\b"; + break; + case '\f': + result += "\\f"; + break; + case '\n': + result += "\\n"; + break; + case '\r': + result += "\\r"; + break; + case '\t': + result += "\\t"; + break; + // case '/': + // Even though \/ is considered a legal escape in JSON, a bare + // slash is also legal, so I see no reason to escape it. + // (I hope I am not misunderstanding something. + // blep notes: actually escaping \/ may be useful in javascript to avoid (*c); + result += oss.str(); + } else { + result += *c; + } + break; + } + } + result += "\""; + return result; +} + +// https://github.com/upcaste/upcaste/blob/master/src/upcore/src/cstring/strnpbrk.cpp +static char const* strnpbrk(char const* s, char const* accept, size_t n) { + assert((s || !n) && accept); + + char const* const end = s + n; + for (char const* cur = s; cur < end; ++cur) { + int const c = *cur; + for (char const* a = accept; *a; ++a) { + if (*a == c) { + return cur; + } + } + } + return NULL; +} +static std::string valueToQuotedStringN(const char* value, unsigned length) { + if (value == NULL) + return ""; + // Not sure how to handle unicode... + if (strnpbrk(value, "\"\\\b\f\n\r\t", length) == NULL && + !containsControlCharacter0(value, length)) + return std::string("\"") + value + "\""; + // We have to walk value and escape any special characters. + // Appending to std::string is not efficient, but this should be rare. + // (Note: forward slashes are *not* rare, but I am not escaping them.) + std::string::size_type maxsize = + length * 2 + 3; // allescaped+quotes+NULL + std::string result; + result.reserve(maxsize); // to avoid lots of mallocs + result += "\""; + char const* end = value + length; + for (const char* c = value; c != end; ++c) { + switch (*c) { + case '\"': + result += "\\\""; + break; + case '\\': + result += "\\\\"; + break; + case '\b': + result += "\\b"; + break; + case '\f': + result += "\\f"; + break; + case '\n': + result += "\\n"; + break; + case '\r': + result += "\\r"; + break; + case '\t': + result += "\\t"; + break; + // case '/': + // Even though \/ is considered a legal escape in JSON, a bare + // slash is also legal, so I see no reason to escape it. + // (I hope I am not misunderstanding something.) + // blep notes: actually escaping \/ may be useful in javascript to avoid (*c); + result += oss.str(); + } else { + result += *c; + } + break; + } + } + result += "\""; + return result; +} + +// Class Writer +// ////////////////////////////////////////////////////////////////// +Writer::~Writer() {} + +// Class FastWriter +// ////////////////////////////////////////////////////////////////// + +FastWriter::FastWriter() + : yamlCompatiblityEnabled_(false) {} + +void FastWriter::enableYAMLCompatibility() { yamlCompatiblityEnabled_ = true; } + +std::string FastWriter::write(const Value& root) { + document_ = ""; + writeValue(root); + document_ += "\n"; + return document_; +} + +void FastWriter::writeValue(const Value& value) { + switch (value.type()) { + case nullValue: + document_ += "null"; + break; + case intValue: + document_ += valueToString(value.asLargestInt()); + break; + case uintValue: + document_ += valueToString(value.asLargestUInt()); + break; + case realValue: + document_ += valueToString(value.asDouble()); + break; + case stringValue: + { + // Is NULL possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) document_ += valueToQuotedStringN(str, static_cast(end-str)); + break; + } + case booleanValue: + document_ += valueToString(value.asBool()); + break; + case arrayValue: { + document_ += '['; + int size = value.size(); + for (int index = 0; index < size; ++index) { + if (index > 0) + document_ += ','; + writeValue(value[index]); + } + document_ += ']'; + } break; + case objectValue: { + Value::Members members(value.getMemberNames()); + document_ += '{'; + for (Value::Members::iterator it = members.begin(); it != members.end(); + ++it) { + const std::string& name = *it; + if (it != members.begin()) + document_ += ','; + document_ += valueToQuotedStringN(name.data(), static_cast(name.length())); + document_ += yamlCompatiblityEnabled_ ? ": " : ":"; + writeValue(value[name]); + } + document_ += '}'; + } break; + } +} + +// Class StyledWriter +// ////////////////////////////////////////////////////////////////// + +StyledWriter::StyledWriter() + : rightMargin_(74), indentSize_(3), addChildValues_() {} + +std::string StyledWriter::write(const Value& root) { + document_ = ""; + addChildValues_ = false; + indentString_ = ""; + writeCommentBeforeValue(root); + writeValue(root); + writeCommentAfterValueOnSameLine(root); + document_ += "\n"; + return document_; +} + +void StyledWriter::writeValue(const Value& value) { + switch (value.type()) { + case nullValue: + pushValue("null"); + break; + case intValue: + pushValue(valueToString(value.asLargestInt())); + break; + case uintValue: + pushValue(valueToString(value.asLargestUInt())); + break; + case realValue: + pushValue(valueToString(value.asDouble())); + break; + case stringValue: + { + // Is NULL possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) pushValue(valueToQuotedStringN(str, static_cast(end-str))); + else pushValue(""); + break; + } + case booleanValue: + pushValue(valueToString(value.asBool())); + break; + case arrayValue: + writeArrayValue(value); + break; + case objectValue: { + Value::Members members(value.getMemberNames()); + if (members.empty()) + pushValue("{}"); + else { + writeWithIndent("{"); + indent(); + Value::Members::iterator it = members.begin(); + for (;;) { + const std::string& name = *it; + const Value& childValue = value[name]; + writeCommentBeforeValue(childValue); + writeWithIndent(valueToQuotedString(name.c_str())); + document_ += " : "; + writeValue(childValue); + if (++it == members.end()) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + document_ += ','; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("}"); + } + } break; + } +} + +void StyledWriter::writeArrayValue(const Value& value) { + unsigned size = value.size(); + if (size == 0) + pushValue("[]"); + else { + bool isArrayMultiLine = isMultineArray(value); + if (isArrayMultiLine) { + writeWithIndent("["); + indent(); + bool hasChildValue = !childValues_.empty(); + unsigned index = 0; + for (;;) { + const Value& childValue = value[index]; + writeCommentBeforeValue(childValue); + if (hasChildValue) + writeWithIndent(childValues_[index]); + else { + writeIndent(); + writeValue(childValue); + } + if (++index == size) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + document_ += ','; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("]"); + } else // output on a single line + { + assert(childValues_.size() == size); + document_ += "[ "; + for (unsigned index = 0; index < size; ++index) { + if (index > 0) + document_ += ", "; + document_ += childValues_[index]; + } + document_ += " ]"; + } + } +} + +bool StyledWriter::isMultineArray(const Value& value) { + int size = value.size(); + bool isMultiLine = size * 3 >= rightMargin_; + childValues_.clear(); + for (int index = 0; index < size && !isMultiLine; ++index) { + const Value& childValue = value[index]; + isMultiLine = + isMultiLine || ((childValue.isArray() || childValue.isObject()) && + childValue.size() > 0); + } + if (!isMultiLine) // check if line length > max line length + { + childValues_.reserve(size); + addChildValues_ = true; + int lineLength = 4 + (size - 1) * 2; // '[ ' + ', '*n + ' ]' + for (int index = 0; index < size; ++index) { + if (hasCommentForValue(value[index])) { + isMultiLine = true; + } + writeValue(value[index]); + lineLength += int(childValues_[index].length()); + } + addChildValues_ = false; + isMultiLine = isMultiLine || lineLength >= rightMargin_; + } + return isMultiLine; +} + +void StyledWriter::pushValue(const std::string& value) { + if (addChildValues_) + childValues_.push_back(value); + else + document_ += value; +} + +void StyledWriter::writeIndent() { + if (!document_.empty()) { + char last = document_[document_.length() - 1]; + if (last == ' ') // already indented + return; + if (last != '\n') // Comments may add new-line + document_ += '\n'; + } + document_ += indentString_; +} + +void StyledWriter::writeWithIndent(const std::string& value) { + writeIndent(); + document_ += value; +} + +void StyledWriter::indent() { indentString_ += std::string(indentSize_, ' '); } + +void StyledWriter::unindent() { + assert(int(indentString_.size()) >= indentSize_); + indentString_.resize(indentString_.size() - indentSize_); +} + +void StyledWriter::writeCommentBeforeValue(const Value& root) { + if (!root.hasComment(commentBefore)) + return; + + document_ += "\n"; + writeIndent(); + const std::string& comment = root.getComment(commentBefore); + std::string::const_iterator iter = comment.begin(); + while (iter != comment.end()) { + document_ += *iter; + if (*iter == '\n' && + (iter != comment.end() && *(iter + 1) == '/')) + writeIndent(); + ++iter; + } + + // Comments are stripped of trailing newlines, so add one here + document_ += "\n"; +} + +void StyledWriter::writeCommentAfterValueOnSameLine(const Value& root) { + if (root.hasComment(commentAfterOnSameLine)) + document_ += " " + root.getComment(commentAfterOnSameLine); + + if (root.hasComment(commentAfter)) { + document_ += "\n"; + document_ += root.getComment(commentAfter); + document_ += "\n"; + } +} + +bool StyledWriter::hasCommentForValue(const Value& value) { + return value.hasComment(commentBefore) || + value.hasComment(commentAfterOnSameLine) || + value.hasComment(commentAfter); +} + +// Class StyledStreamWriter +// ////////////////////////////////////////////////////////////////// + +StyledStreamWriter::StyledStreamWriter(std::string indentation) + : document_(NULL), rightMargin_(74), indentation_(indentation), + addChildValues_() {} + +void StyledStreamWriter::write(std::ostream& out, const Value& root) { + document_ = &out; + addChildValues_ = false; + indentString_ = ""; + indented_ = true; + writeCommentBeforeValue(root); + if (!indented_) writeIndent(); + indented_ = true; + writeValue(root); + writeCommentAfterValueOnSameLine(root); + *document_ << "\n"; + document_ = NULL; // Forget the stream, for safety. +} + +void StyledStreamWriter::writeValue(const Value& value) { + switch (value.type()) { + case nullValue: + pushValue("null"); + break; + case intValue: + pushValue(valueToString(value.asLargestInt())); + break; + case uintValue: + pushValue(valueToString(value.asLargestUInt())); + break; + case realValue: + pushValue(valueToString(value.asDouble())); + break; + case stringValue: + { + // Is NULL possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) pushValue(valueToQuotedStringN(str, static_cast(end-str))); + else pushValue(""); + break; + } + case booleanValue: + pushValue(valueToString(value.asBool())); + break; + case arrayValue: + writeArrayValue(value); + break; + case objectValue: { + Value::Members members(value.getMemberNames()); + if (members.empty()) + pushValue("{}"); + else { + writeWithIndent("{"); + indent(); + Value::Members::iterator it = members.begin(); + for (;;) { + const std::string& name = *it; + const Value& childValue = value[name]; + writeCommentBeforeValue(childValue); + writeWithIndent(valueToQuotedString(name.c_str())); + *document_ << " : "; + writeValue(childValue); + if (++it == members.end()) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *document_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("}"); + } + } break; + } +} + +void StyledStreamWriter::writeArrayValue(const Value& value) { + unsigned size = value.size(); + if (size == 0) + pushValue("[]"); + else { + bool isArrayMultiLine = isMultineArray(value); + if (isArrayMultiLine) { + writeWithIndent("["); + indent(); + bool hasChildValue = !childValues_.empty(); + unsigned index = 0; + for (;;) { + const Value& childValue = value[index]; + writeCommentBeforeValue(childValue); + if (hasChildValue) + writeWithIndent(childValues_[index]); + else { + if (!indented_) writeIndent(); + indented_ = true; + writeValue(childValue); + indented_ = false; + } + if (++index == size) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *document_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("]"); + } else // output on a single line + { + assert(childValues_.size() == size); + *document_ << "[ "; + for (unsigned index = 0; index < size; ++index) { + if (index > 0) + *document_ << ", "; + *document_ << childValues_[index]; + } + *document_ << " ]"; + } + } +} + +bool StyledStreamWriter::isMultineArray(const Value& value) { + int size = value.size(); + bool isMultiLine = size * 3 >= rightMargin_; + childValues_.clear(); + for (int index = 0; index < size && !isMultiLine; ++index) { + const Value& childValue = value[index]; + isMultiLine = + isMultiLine || ((childValue.isArray() || childValue.isObject()) && + childValue.size() > 0); + } + if (!isMultiLine) // check if line length > max line length + { + childValues_.reserve(size); + addChildValues_ = true; + int lineLength = 4 + (size - 1) * 2; // '[ ' + ', '*n + ' ]' + for (int index = 0; index < size; ++index) { + if (hasCommentForValue(value[index])) { + isMultiLine = true; + } + writeValue(value[index]); + lineLength += int(childValues_[index].length()); + } + addChildValues_ = false; + isMultiLine = isMultiLine || lineLength >= rightMargin_; + } + return isMultiLine; +} + +void StyledStreamWriter::pushValue(const std::string& value) { + if (addChildValues_) + childValues_.push_back(value); + else + *document_ << value; +} + +void StyledStreamWriter::writeIndent() { + // blep intended this to look at the so-far-written string + // to determine whether we are already indented, but + // with a stream we cannot do that. So we rely on some saved state. + // The caller checks indented_. + *document_ << '\n' << indentString_; +} + +void StyledStreamWriter::writeWithIndent(const std::string& value) { + if (!indented_) writeIndent(); + *document_ << value; + indented_ = false; +} + +void StyledStreamWriter::indent() { indentString_ += indentation_; } + +void StyledStreamWriter::unindent() { + assert(indentString_.size() >= indentation_.size()); + indentString_.resize(indentString_.size() - indentation_.size()); +} + +void StyledStreamWriter::writeCommentBeforeValue(const Value& root) { + if (!root.hasComment(commentBefore)) + return; + + if (!indented_) writeIndent(); + const std::string& comment = root.getComment(commentBefore); + std::string::const_iterator iter = comment.begin(); + while (iter != comment.end()) { + *document_ << *iter; + if (*iter == '\n' && + (iter != comment.end() && *(iter + 1) == '/')) + // writeIndent(); // would include newline + *document_ << indentString_; + ++iter; + } + indented_ = false; +} + +void StyledStreamWriter::writeCommentAfterValueOnSameLine(const Value& root) { + if (root.hasComment(commentAfterOnSameLine)) + *document_ << ' ' << root.getComment(commentAfterOnSameLine); + + if (root.hasComment(commentAfter)) { + writeIndent(); + *document_ << root.getComment(commentAfter); + } + indented_ = false; +} + +bool StyledStreamWriter::hasCommentForValue(const Value& value) { + return value.hasComment(commentBefore) || + value.hasComment(commentAfterOnSameLine) || + value.hasComment(commentAfter); +} + +////////////////////////// +// BuiltStyledStreamWriter + +/// Scoped enums are not available until C++11. +struct CommentStyle { + /// Decide whether to write comments. + enum Enum { + None, ///< Drop all comments. + Most, ///< Recover odd behavior of previous versions (not implemented yet). + All ///< Keep all comments. + }; +}; + +struct BuiltStyledStreamWriter : public StreamWriter +{ + BuiltStyledStreamWriter( + std::string const& indentation, + CommentStyle::Enum cs, + std::string const& colonSymbol, + std::string const& nullSymbol, + std::string const& endingLineFeedSymbol, + bool useSpecialFloats, + unsigned int precision); + virtual int write(Value const& root, std::ostream* sout); +private: + void writeValue(Value const& value); + void writeArrayValue(Value const& value); + bool isMultineArray(Value const& value); + void pushValue(std::string const& value); + void writeIndent(); + void writeWithIndent(std::string const& value); + void indent(); + void unindent(); + void writeCommentBeforeValue(Value const& root); + void writeCommentAfterValueOnSameLine(Value const& root); + static bool hasCommentForValue(const Value& value); + + typedef std::vector ChildValues; + + ChildValues childValues_; + std::string indentString_; + int rightMargin_; + std::string indentation_; + CommentStyle::Enum cs_; + std::string colonSymbol_; + std::string nullSymbol_; + std::string endingLineFeedSymbol_; + bool addChildValues_ : 1; + bool indented_ : 1; + bool useSpecialFloats_ : 1; + unsigned int precision_; +}; +BuiltStyledStreamWriter::BuiltStyledStreamWriter( + std::string const& indentation, + CommentStyle::Enum cs, + std::string const& colonSymbol, + std::string const& nullSymbol, + std::string const& endingLineFeedSymbol, + bool useSpecialFloats, + unsigned int precision) + : rightMargin_(74) + , indentation_(indentation) + , cs_(cs) + , colonSymbol_(colonSymbol) + , nullSymbol_(nullSymbol) + , endingLineFeedSymbol_(endingLineFeedSymbol) + , addChildValues_(false) + , indented_(false) + , useSpecialFloats_(useSpecialFloats) + , precision_(precision) +{ +} +int BuiltStyledStreamWriter::write(Value const& root, std::ostream* sout) +{ + sout_ = sout; + addChildValues_ = false; + indented_ = true; + indentString_ = ""; + writeCommentBeforeValue(root); + if (!indented_) writeIndent(); + indented_ = true; + writeValue(root); + writeCommentAfterValueOnSameLine(root); + *sout_ << endingLineFeedSymbol_; + sout_ = NULL; + return 0; +} +void BuiltStyledStreamWriter::writeValue(Value const& value) { + switch (value.type()) { + case nullValue: + pushValue(nullSymbol_); + break; + case intValue: + pushValue(valueToString(value.asLargestInt())); + break; + case uintValue: + pushValue(valueToString(value.asLargestUInt())); + break; + case realValue: + pushValue(valueToString(value.asDouble(), useSpecialFloats_, precision_)); + break; + case stringValue: + { + // Is NULL is possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) pushValue(valueToQuotedStringN(str, static_cast(end-str))); + else pushValue(""); + break; + } + case booleanValue: + pushValue(valueToString(value.asBool())); + break; + case arrayValue: + writeArrayValue(value); + break; + case objectValue: { + Value::Members members(value.getMemberNames()); + if (members.empty()) + pushValue("{}"); + else { + writeWithIndent("{"); + indent(); + Value::Members::iterator it = members.begin(); + for (;;) { + std::string const& name = *it; + Value const& childValue = value[name]; + writeCommentBeforeValue(childValue); + writeWithIndent(valueToQuotedStringN(name.data(), static_cast(name.length()))); + *sout_ << colonSymbol_; + writeValue(childValue); + if (++it == members.end()) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *sout_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("}"); + } + } break; + } +} + +void BuiltStyledStreamWriter::writeArrayValue(Value const& value) { + unsigned size = value.size(); + if (size == 0) + pushValue("[]"); + else { + bool isMultiLine = (cs_ == CommentStyle::All) || isMultineArray(value); + if (isMultiLine) { + writeWithIndent("["); + indent(); + bool hasChildValue = !childValues_.empty(); + unsigned index = 0; + for (;;) { + Value const& childValue = value[index]; + writeCommentBeforeValue(childValue); + if (hasChildValue) + writeWithIndent(childValues_[index]); + else { + if (!indented_) writeIndent(); + indented_ = true; + writeValue(childValue); + indented_ = false; + } + if (++index == size) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *sout_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("]"); + } else // output on a single line + { + assert(childValues_.size() == size); + *sout_ << "["; + if (!indentation_.empty()) *sout_ << " "; + for (unsigned index = 0; index < size; ++index) { + if (index > 0) + *sout_ << ", "; + *sout_ << childValues_[index]; + } + if (!indentation_.empty()) *sout_ << " "; + *sout_ << "]"; + } + } +} + +bool BuiltStyledStreamWriter::isMultineArray(Value const& value) { + int size = value.size(); + bool isMultiLine = size * 3 >= rightMargin_; + childValues_.clear(); + for (int index = 0; index < size && !isMultiLine; ++index) { + Value const& childValue = value[index]; + isMultiLine = + isMultiLine || ((childValue.isArray() || childValue.isObject()) && + childValue.size() > 0); + } + if (!isMultiLine) // check if line length > max line length + { + childValues_.reserve(size); + addChildValues_ = true; + int lineLength = 4 + (size - 1) * 2; // '[ ' + ', '*n + ' ]' + for (int index = 0; index < size; ++index) { + if (hasCommentForValue(value[index])) { + isMultiLine = true; + } + writeValue(value[index]); + lineLength += int(childValues_[index].length()); + } + addChildValues_ = false; + isMultiLine = isMultiLine || lineLength >= rightMargin_; + } + return isMultiLine; +} + +void BuiltStyledStreamWriter::pushValue(std::string const& value) { + if (addChildValues_) + childValues_.push_back(value); + else + *sout_ << value; +} + +void BuiltStyledStreamWriter::writeIndent() { + // blep intended this to look at the so-far-written string + // to determine whether we are already indented, but + // with a stream we cannot do that. So we rely on some saved state. + // The caller checks indented_. + + if (!indentation_.empty()) { + // In this case, drop newlines too. + *sout_ << '\n' << indentString_; + } +} + +void BuiltStyledStreamWriter::writeWithIndent(std::string const& value) { + if (!indented_) writeIndent(); + *sout_ << value; + indented_ = false; +} + +void BuiltStyledStreamWriter::indent() { indentString_ += indentation_; } + +void BuiltStyledStreamWriter::unindent() { + assert(indentString_.size() >= indentation_.size()); + indentString_.resize(indentString_.size() - indentation_.size()); +} + +void BuiltStyledStreamWriter::writeCommentBeforeValue(Value const& root) { + if (cs_ == CommentStyle::None) return; + if (!root.hasComment(commentBefore)) + return; + + if (!indented_) writeIndent(); + const std::string& comment = root.getComment(commentBefore); + std::string::const_iterator iter = comment.begin(); + while (iter != comment.end()) { + *sout_ << *iter; + if (*iter == '\n' && + (iter != comment.end() && *(iter + 1) == '/')) + // writeIndent(); // would write extra newline + *sout_ << indentString_; + ++iter; + } + indented_ = false; +} + +void BuiltStyledStreamWriter::writeCommentAfterValueOnSameLine(Value const& root) { + if (cs_ == CommentStyle::None) return; + if (root.hasComment(commentAfterOnSameLine)) + *sout_ << " " + root.getComment(commentAfterOnSameLine); + + if (root.hasComment(commentAfter)) { + writeIndent(); + *sout_ << root.getComment(commentAfter); + } +} + +// static +bool BuiltStyledStreamWriter::hasCommentForValue(const Value& value) { + return value.hasComment(commentBefore) || + value.hasComment(commentAfterOnSameLine) || + value.hasComment(commentAfter); +} + +/////////////// +// StreamWriter + +StreamWriter::StreamWriter() + : sout_(NULL) +{ +} +StreamWriter::~StreamWriter() +{ +} +StreamWriter::Factory::~Factory() +{} +StreamWriterBuilder::StreamWriterBuilder() +{ + setDefaults(&settings_); +} +StreamWriterBuilder::~StreamWriterBuilder() +{} +StreamWriter* StreamWriterBuilder::newStreamWriter() const +{ + std::string indentation = settings_["indentation"].asString(); + std::string cs_str = settings_["commentStyle"].asString(); + bool eyc = settings_["enableYAMLCompatibility"].asBool(); + bool dnp = settings_["dropNullPlaceholders"].asBool(); + bool usf = settings_["useSpecialFloats"].asBool(); + unsigned int pre = settings_["precision"].asUInt(); + CommentStyle::Enum cs = CommentStyle::All; + if (cs_str == "All") { + cs = CommentStyle::All; + } else if (cs_str == "None") { + cs = CommentStyle::None; + } else { + throwRuntimeError("commentStyle must be 'All' or 'None'"); + } + std::string colonSymbol = " : "; + if (eyc) { + colonSymbol = ": "; + } else if (indentation.empty()) { + colonSymbol = ":"; + } + std::string nullSymbol = "null"; + if (dnp) { + nullSymbol = ""; + } + if (pre > 17) pre = 17; + std::string endingLineFeedSymbol = ""; + return new BuiltStyledStreamWriter( + indentation, cs, + colonSymbol, nullSymbol, endingLineFeedSymbol, usf, pre); +} +static void getValidWriterKeys(std::set* valid_keys) +{ + valid_keys->clear(); + valid_keys->insert("indentation"); + valid_keys->insert("commentStyle"); + valid_keys->insert("enableYAMLCompatibility"); + valid_keys->insert("dropNullPlaceholders"); + valid_keys->insert("useSpecialFloats"); + valid_keys->insert("precision"); +} +bool StreamWriterBuilder::validate(Json::Value* invalid) const +{ + Json::Value my_invalid; + if (!invalid) invalid = &my_invalid; // so we do not need to test for NULL + Json::Value& inv = *invalid; + std::set valid_keys; + getValidWriterKeys(&valid_keys); + Value::Members keys = settings_.getMemberNames(); + size_t n = keys.size(); + for (size_t i = 0; i < n; ++i) { + std::string const& key = keys[i]; + if (valid_keys.find(key) == valid_keys.end()) { + inv[key] = settings_[key]; + } + } + return 0u == inv.size(); +} +Value& StreamWriterBuilder::operator[](std::string key) +{ + return settings_[key]; +} +// static +void StreamWriterBuilder::setDefaults(Json::Value* settings) +{ + //! [StreamWriterBuilderDefaults] + (*settings)["commentStyle"] = "All"; + (*settings)["indentation"] = "\t"; + (*settings)["enableYAMLCompatibility"] = false; + (*settings)["dropNullPlaceholders"] = false; + (*settings)["useSpecialFloats"] = false; + (*settings)["precision"] = 17; + //! [StreamWriterBuilderDefaults] +} + +std::string writeString(StreamWriter::Factory const& builder, Value const& root) { + std::ostringstream sout; + StreamWriterPtr const writer(builder.newStreamWriter()); + writer->write(root, &sout); + return sout.str(); +} + +std::ostream& operator<<(std::ostream& sout, Value const& root) { + StreamWriterBuilder builder; + StreamWriterPtr const writer(builder.newStreamWriter()); + writer->write(root, &sout); + return sout; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_writer.cpp +// ////////////////////////////////////////////////////////////////////// + + + + diff --git a/tools/onnx-subgraph/src/lib/partition.cpp b/tools/onnx-subgraph/src/lib/partition.cpp new file mode 100644 index 00000000000..c480be9923c --- /dev/null +++ b/tools/onnx-subgraph/src/lib/partition.cpp @@ -0,0 +1,2727 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "partition.h" +#include +#include +#include +#define MAX_DEPTH 1000 +std::vector Subgraphs; +/** + * Prints the subgraph information of an ONNX model to specified files. + * + * @param Subgraphs A vector containing subgraph information. + * @param subgraph_file_name The filename for the output of subgraph information. + * @param otherSubgraphs A vector containing other subgraph information. + * @param other_subgraph_file_name The filename for the output of other subgraph information. + */ +void print_subgraphs(std::vector Subgraphs, char* subgraph_file_name, std::vector otherSubgraphs, char* other_subgraph_file_name) +{ + int node_sum = 0; + std::ofstream outFile(subgraph_file_name); + if (!outFile.is_open()) { + std::cerr << "Error opening file." << std::endl; + exit(0); + } + int id = 0; + for (const auto& vec : Subgraphs) { + outFile << " subgraph" << id << ":"; + for (const auto& node : vec.node()) { + outFile << node.name() << " "; + } + id++; + outFile << std::endl; + node_sum += vec.node_size(); + } + std::ofstream outFile_2(other_subgraph_file_name); + if (!outFile_2.is_open()) { + std::cerr << "Error opening file." << std::endl; + exit(0); + } + std::cout << "before:" << std::endl; + for (const auto& vec : otherSubgraphs) { + outFile_2 << " subgraph" << id << ":"; + for (const auto& node : vec.node()) { + outFile_2 << node.name() << " "; + } + id++; + outFile_2 << std::endl; + node_sum += vec.node_size(); + } +} +/////// +/** + * @brief Constructs an adjacency list representation of the ONNX graph. + * + * @param [in] g A const reference to an ONNX GraphProto object that contains the graph structure. + * @param [in,out] visited A pointer to an integer array used to mark whether nodes have been visited. + * @pre The 'visited' array should be pre-allocated with a size at least equal to the number of nodes in the graph. + * @post The 'visited' array will be initialized to 0 for all nodes. + * @exception None + * @return A vector of graph_adjacency_node objects representing the adjacency list of the graph. + */ +std::vector get_adjancency_list(const onnx::GraphProto &g, int* visited) +{ + std::vector adjacency_list; + int node_index=0; + for(const auto& node : g.node()) + { + visited[node_index]=0; + graph_adjacency_node ad_node; + ad_node.index = node_index; + ad_node.name = node.name(); + const auto& outputs = node.output(); + for(const auto& output : outputs) + { + int output_node_index=0; + for(const auto& output_node : g.node()) + { + int find_flag=0; + const auto& inputs = output_node.input(); + for(const auto& input : inputs) + { + if(output==input) + { + find_flag=1; + break; + } + } + if(find_flag==1) + { + if(std::find(ad_node.output_node_index.begin(),ad_node.output_node_index.end(),output_node_index)==ad_node.output_node_index.end()) + { + ad_node.output_node_index.push_back(output_node_index); + } + } + output_node_index++; + } + } + node_index++; + adjacency_list.push_back(ad_node); + } + return adjacency_list; +} +/** + * @brief Calculates the size of a specific node in the ONNX graph in kilobytes (KB). + * + * @param [in] g A const reference to an ONNX GraphProto object that contains the graph structure. + * @param [in] node_index The index of the node for which the size is to be calculated. + * @pre The node_index should be a valid index within the range of nodes in the graph. + * @post None + * @exception None + * @return The size of the node in kilobytes (KB). + */ +float calculate_node_size(const onnx::GraphProto &g, int node_index)//unit : KB +{ + int64_t node_size = 0; + for(int i = 0; i < g.node(node_index).input_size(); i++) + { + std::string input_name = g.node(node_index).input(i); + for(int j = 0; j < g.initializer_size(); j++) + { + if(g.initializer(j).name() == input_name) + { + int64_t node_init_size = 4; + for(int k = 0; k < g.initializer(j).dims().size(); k ++) + { + node_init_size = g.initializer(j).dims(k) * node_init_size; + } + node_size += node_init_size; + break; + } + } + } + return float(node_size*1.0/1024.0); +} +/** + * @brief Depth-First Search (DFS) to build a NPU subgraph. + * + * @param [in] onnx_graph Input ONNX graph structure. + * @param [out] onnx_subgraph Output subgraph. + * @param [in,out] subgraph_node_indices Vector storing indices of nodes in the subgraph. + * @param [in,out] visited Array recording whether nodes have been visited. + * @param [in] start_node Current starting node for the search. + * @param [in] current_node_index Index of the current node. + * @param [in] adjacency_list Adjacency list representing connections between nodes in the graph. + * @param [in] supported_op_types List of supported operation types. + * @param [in] preferred_op_types List of preferred operation types (not used in the code). + * @param [in] current_depth Current depth of the search. + * @param [in,out] current_graph_size Current size of the subgraph. + * @param [in] max_graph_size Maximum allowed size of the subgraph. + * @pre `current_node_index` should be a valid node index. + * @post If the subgraph size exceeds `max_graph_size`, a warning message is printed. + * @exception None + */ +void DFS(const onnx::GraphProto &g,onnx::GraphProto &subgraph, std::vector &sugraph_node_index, + int* visited, const onnx::NodeProto& start_node, + int node_index,std::vector& adjacency_list, + const std::vector& support_op, + const std::vector& prefer_op, + int depth_in, + float& graph_size, + float max_graph_size) +{ + int depth_out = depth_in + 1; + *subgraph.add_node()=start_node; + visited[node_index]=1; + sugraph_node_index.push_back(node_index); + float node_size = calculate_node_size(g, node_index); + graph_size += node_size; + if(graph_size>max_graph_size) + { + std::cout<<"graph size exceed max size!"<1) + { + std::cout<"; + } + // + int next_node_index=adjacency_list[node_index].output_node_index[i]; + const auto & next_node=g.node(next_node_index); + if(!visited[next_node_index]&&(std::find(support_op.begin(), support_op.end(), next_node.op_type()) != support_op.end())&&(depth_out < MAX_DEPTH)&&(graph_size < max_graph_size)) //尚未访问且op_type符合的邻接顶点 + DFS(g,subgraph,sugraph_node_index,visited,next_node,next_node_index,adjacency_list,support_op, prefer_op, depth_out, graph_size, max_graph_size); + } +} +/** +* @brief Perform a depth-first search (DFS) to build a CPU subgraph from a given starting node. +* +* @param [in] g The original ONNX graph from which the subgraph will be extracted. +* @param [out] subgraph The subgraph being constructed. +* @param [out] subgraph_node_indices A vector to store indices of nodes included in the subgraph. +* @param [in,out] visited An array to keep track of visited nodes. +* @param [in] start_node The starting node for the DFS. +* @param [in] node_index The index of the starting node in the original graph. +* @param [in] adjacency_list The adjacency list representing the graph's structure. +* @param [in] depth_in The current depth of the DFS. +* @param [in,out] graph_size The cumulative size of the nodes in the subgraph. +* @param [in] max_graph_size The maximum allowed size for the subgraph. +* +* @pre The graph `g` and `adjacency_list` should be properly initialized. +* @pre The `visited` array should be initialized to zero. +* @pre `graph_size` should be initialized to zero before the first call to this function. +* +* @post The `subgraph` will contain the nodes visited during the DFS. +* @post The `subgraph_node_indices` will contain the indices of the nodes in the subgraph. +* @post The `visited` array will reflect the nodes that have been visited. +* @post The `graph_size` will reflect the cumulative size of the nodes in the subgraph. +* +* @exception None +* +* @return None +*/ +void DFS_other(const onnx::GraphProto &g,onnx::GraphProto &subgraph, std::vector &sugraph_node_index, + int* visited, const onnx::NodeProto& start_node, + int node_index,std::vector& adjacency_list, int depth_in + ,float& graph_size, + float max_graph_size) +{ + int depth_out = depth_in + 1; + *subgraph.add_node()=start_node; + visited[node_index]=1; + sugraph_node_index.push_back(node_index); + float node_size = calculate_node_size(g, node_index); + graph_size += node_size; + if(graph_size>max_graph_size) + { + std::cout<<"graph size exceed max size!"<& otherSubgraphs, Device& d, int* visited, + std::vector& adjacency_list,PartitionStrategy strategy) +{ + int max_subgraph_size = d.max_subgraph_size; + std::vector support_op; + std::vector prefer_op; + switch(strategy) { + case SPILTE_CPU_STRUCTURE_FIRST:{ + support_op=d.getCPUSupportOp(); + break; + } + case SPILTE_NPU_STRUCTURE_FIRST:{ + support_op=d.getNPUSupportOp(); + prefer_op=d.getNPUPreferOp(); + break; + } + default: + break; + } + for(int i=0;i sugraph_node_index; + const auto& node=g.node(i); + int depth = 0; + float graph_size = 0; + DFS(g,subgraph,sugraph_node_index,visited,node,i,adjacency_list,support_op, prefer_op,depth, graph_size, max_subgraph_size); + std::cout<<"graph_size: "< sugraph_node_index; + const auto& node=g.node(i); + DFS_other(g,subgraph,sugraph_node_index,visited,node,i,adjacency_list, depth, graph_size, max_subgraph_size); + std::cout<<"graph_size:"<& otherSubgraphs, Device& d, int* visited, + std::vector& adjacency_list,PartitionStrategy strategy) +{ + float max_subgraph_size = d.max_subgraph_size; + std::vector support_op; + std::vector prefer_op; + support_op=d.getNPUSupportOp(); + prefer_op=d.getNPUPreferOp(); + onnx::GraphProto temp_graph; + int end_flag = 0; + int node_count = 0; + float temp_graph_size = 0; + while(!end_flag) + { + float node_size = calculate_node_size(g, node_count); + if(temp_graph.node_size()!= 0) + { + if((std::find(support_op.begin(), support_op.end(), g.node(node_count).op_type()) != support_op.end())&&temp_graph.node_size()<=max_subgraph_size) + { + *temp_graph.add_node() = g.node(node_count); + temp_graph_size += node_size; + if(temp_graph_size>max_subgraph_size) + { + std::cout<<"graph size exceed max size!"<max_subgraph_size) + { + std::cout<<"graph size exceed max size!"<>& strongly_connected_subgraphs,int* DFN, + int* LOW, std::vector& stack_subgraphs, std::vector>& successors_Subgraphs) +{ + int rank = depth + 1; + DFN[index] = LOW[index] = rank;// initialize DFN and LOW to 0 + stack_subgraphs.push_back(index); + for(const auto& successor : successors_Subgraphs[index]) + { + if(DFN[successor] == 0)//the successor is not visited + { + Tarjan(successor, rank,strongly_connected_subgraphs, DFN, LOW, stack_subgraphs, successors_Subgraphs);//visit successor + LOW[index] = std::min(LOW[index], LOW[successor]); + } + else if(std::find(stack_subgraphs.begin(),stack_subgraphs.end(),successor) != stack_subgraphs.end()) + { + LOW[index] = std::min(LOW[index], DFN[successor]); + } + } + if(LOW[index] == DFN[index])//if this node is the smallest root of the strongly connected component subtree, then subsequent nodes are popped out of the stack and the obtained strongly connected components are saved. + { + auto it = stack_subgraphs.end() - 1; + std:: vector strongly_connected; + while(*it != index) + { + strongly_connected.insert(strongly_connected.begin(), *it); + stack_subgraphs.pop_back(); + it = stack_subgraphs.end() - 1; + } + strongly_connected.insert(strongly_connected.begin(), *it); + + if(strongly_connected.size() > 1) + { + strongly_connected_subgraphs.push_back(strongly_connected); + } + stack_subgraphs.pop_back();//pop + + } +} +/** +* @brief Calculate the rank of each node in the merged graph formed by the given strongly connected components. +* The rank is determined based on the topological order of the nodes. +* +* @param [in] strongly_connected A vector containing indices of strongly connected components. +* @param [in] Subgraphs A vector of ONNX GraphProtos representing the main subgraphs. +* @param [in] otherSubgraphs A vector of ONNX GraphProtos representing additional subgraphs. +* +* @pre The `strongly_connected` vector should contain valid indices for `Subgraphs` and `otherSubgraphs`. +* @pre The `Subgraphs` and `otherSubgraphs` vectors should be properly initialized with ONNX GraphProtos. +* +* @post The `node_rank_list` vector will contain the nodes of the merged graph with their respective ranks. +* +* @exception None +* +* @return A vector of `graph_adjacency_node` structures containing the nodes and their ranks. +*/ +std::vector calculate_node_rank( + std::vector& strongly_connected, + std::vector& Subgraphs, + std::vector& otherSubgraphs + ) +{ + onnx::GraphProto merged_graph; + std::vector node_rank_list; + for(const auto& index : strongly_connected) + { + if(index < int(Subgraphs.size())) + { + mergeGraphs(merged_graph, Subgraphs[index]); + } + else + { + mergeGraphs(merged_graph, otherSubgraphs[index - Subgraphs.size()]); + } + } + int index = 0; + for(const auto& node : merged_graph.node()) + { + graph_adjacency_node node_rank; + node_rank.name = node.name(); + node_rank.index = index; + node_rank.rank = -1; + node_rank_list.push_back(node_rank); + index ++; + } + int sort_count=0; + int finished_flag=0; + while(!finished_flag) + { + finished_flag=1; + if(sort_count==0) + { + for(int i=0; i= 0 && node_rank_list[i].rank < sort_count){continue;}////If it has already been sorted, skip this subgraph + for(const auto& input : merged_graph.node(i).input())////traveres all inputs of this subgraph + { + for(int j=0; j< merged_graph.node_size(); j++)////examint if the input is the output of j th subgraph + { + for(const auto& output : merged_graph.node(j).output()) + { + if(output==input) + { + if((node_rank_list[j].rank < 0 || node_rank_list[j].rank >= sort_count))//the j th subgraph has not been sorted + { + find_flag=1; + break; + } + } + } + if(find_flag){break;} + + } + if(find_flag){break;} + } + if(!find_flag) + { + node_rank_list[i].rank=sort_count; + } + else {node_rank_list[i].rank=sort_count+1;finished_flag=0;} + } + } + sort_count++; + } + return node_rank_list; +} +/** +* @brief Calculate the rank of each node in the merged graph formed by the given strongly connected components. +* The rank is determined based on the topological order of the nodes. Compared with calculate_node_rank, this function has different input parameters. +* +* @param [in] strongly_connected A vector containing indices of strongly connected components. +* @param [in] Subgraphs A vector of ONNX GraphProtos representing the main subgraphs. +* @param [in] otherSubgraphs A vector of ONNX GraphProtos representing additional subgraphs. +* @param [in] subgraph_size The size of the Subgraphs vector. +* @param [in] other_subgraph_size The size of the otherSubgraphs vector. +* +* @pre The `strongly_connected` vector should contain valid indices for `Subgraphs` and `otherSubgraphs`. +* @pre The `Subgraphs` and `otherSubgraphs` vectors should be properly initialized with ONNX GraphProtos. +* @pre `subgraph_size` should be equal to the size of the `Subgraphs` vector. +* @pre `other_subgraph_size` should be equal to the size of the `otherSubgraphs` vector. +* +* @post The `node_rank_list` vector will contain the nodes of the merged graph with their respective ranks. +* +* @exception None +* +* @return A vector of `graph_adjacency_node` structures containing the nodes and their ranks. +*/ +std::vector calculate_node_rank_v2( + std::vector& strongly_connected, + std::vector& Subgraphs, + std::vector& otherSubgraphs, + int subgraph_size, + int other_subgraph_size + ) +{ + onnx::GraphProto merged_graph; + std::vector node_rank_list; + for(const auto& index : strongly_connected) + { + if(index < subgraph_size) + { + mergeGraphs(merged_graph, Subgraphs[index]); + } + else + { + mergeGraphs(merged_graph, otherSubgraphs[index - subgraph_size]); + } + } + int index = 0; + for(const auto& node : merged_graph.node()) + { + graph_adjacency_node node_rank; + node_rank.name = node.name(); + node_rank.index = index; + node_rank.rank = -1; + node_rank_list.push_back(node_rank); + index ++; + } + int sort_count=0; + int finished_flag=0; + while(!finished_flag) + { + finished_flag=1; + if(sort_count==0) + { + for(int i=0; i= 0 && node_rank_list[i].rank < sort_count){continue;} + for(const auto& input : merged_graph.node(i).input())////traverses all inputs of this subgraph + { + for(int j=0; j< merged_graph.node_size(); j++)///examint if the input is the output of j th subgraph + { + for(const auto& output : merged_graph.node(j).output()) + { + if(output==input) + { + if((node_rank_list[j].rank < 0 || node_rank_list[j].rank >= sort_count))//the j th subgraph has not been sorted + { + find_flag=1; + break; + } + } + } + if(find_flag){break;} + + } + if(find_flag){break;} + } + if(!find_flag) + { + node_rank_list[i].rank=sort_count; + } + else {node_rank_list[i].rank=sort_count+1;finished_flag=0;} + } + } + sort_count++; + } + return node_rank_list; +} +/** +* @brief Calculate the rank of each node in the given merged ONNX graph. +* The rank is determined based on the topological order of the nodes. +* This function is only used to calculate the rank of the nodes in a single graph, especially the original graph +* +* @param [in] merged_graph The ONNX GraphProto representing the merged graph. +* @param [out] node_rank_list A vector of `graph_adjacency_node` structures to store the nodes and their ranks. +* +* @pre The `merged_graph` should be a valid ONNX GraphProto. +* @pre The `node_rank_list` should be an empty vector or properly initialized. +* +* @post The `node_rank_list` vector will contain the nodes of the merged graph with their respective ranks. +* +* @exception None +* +* @return None +*/ +void calculate_node_rank_v3( + const onnx::GraphProto& merged_graph, + std::vector& node_rank_list + ) +{ + int index = 0; + for(const auto& node : merged_graph.node()) + { + graph_adjacency_node node_rank; + node_rank.name = node.name(); + node_rank.index = index; + node_rank.rank = -1; + node_rank_list.push_back(node_rank); + index ++; + } + int sort_count=0; + int finished_flag=0; + while(!finished_flag) + { + finished_flag=1; + if(sort_count==0) + { + for(int i=0; i= 0 && node_rank_list[i].rank < sort_count){continue;} + for(const auto& input : merged_graph.node(i).input())////traverses all inputs of this subgraph + { + for(int j=0; j< merged_graph.node_size(); j++)///examint if the input is the output of j th subgraph + { + for(const auto& output : merged_graph.node(j).output()) + { + if(output==input) + { + if((node_rank_list[j].rank < 0 || node_rank_list[j].rank >= sort_count))//the j th subgraph has not been sorted + { + find_flag=1; + break; + } + } + } + if(find_flag){break;} + + } + if(find_flag){break;} + } + if(!find_flag) + { + node_rank_list[i].rank=sort_count; + } + else {node_rank_list[i].rank=sort_count+1;finished_flag=0;} + } + } + sort_count++; + } +} +/** +* @brief Determine the cut ranks in the given list of SCC (Strongly Connected Component) node ranks. +* A cut rank is defined as a rank where no node exists, but there is at least one node at the next rank. +* +* @param [in] scc_node_rank A vector of `graph_adjacency_node` structures representing the nodes and their ranks. +* +* @pre The `scc_node_rank` vector should be properly initialized and contain valid node ranks. +* +* @post The function does not modify the `scc_node_rank` vector. +* +* @exception None +* +* @return A vector of integers representing the cut ranks. +*/ +std::vector get_cut_rank_v2(std::vector& scc_node_rank) +{ + std::vector cut_rank_list; + int min_cut_rank = -1; + int max_rank = 0; + //get min + for(int i=0; imax_rank) + { + max_rank = scc_node_rank[i].rank; + } + } + int find_flag = 1; + while(find_flag) + { + min_cut_rank ++; + int temp_find_flag = 0; + for(int i=0; i>& strongly_connected_subgraphs, + std::vector& Subgraphs, + std::vector& otherSubgraphs, + const onnx::GraphProto& g +) +{ + int subgraph_size = Subgraphs.size(); + std::vector node_rank_list; + calculate_node_rank_v3(g, node_rank_list); + for(auto& strongly_connected : strongly_connected_subgraphs) + for(const auto scc_index : strongly_connected) + { + onnx::GraphProto scc_graph; + if(scc_index < subgraph_size) + { + scc_graph = Subgraphs[scc_index]; + } + else + { + scc_graph = otherSubgraphs[scc_index - subgraph_size]; + } + std::vector scc_node_rank; + for(int i=0; i< scc_graph.node_size(); i++) + { + for(int j = 0; j < int(node_rank_list.size()); j++) + { + if(scc_graph.node(i).name() == node_rank_list[j].name) + { + scc_node_rank.push_back(node_rank_list[j]); + break; + } + } + } + std::vector cut_rank = get_cut_rank_v2(scc_node_rank); + onnx::GraphProto temp_graph_upper; + int node_in_upper = 0; + for(int i=0; i temp_graph_upper_adder_list; + int record_i = 0; + std::cout<<"node size: "<= cut_rank[0]) + {record_i = i + 1;} + else{record_i = i;} + if(temp_graph_upper_adder.node_size() > 0) + { + temp_graph_upper_adder_list.push_back(temp_graph_upper_adder); + temp_graph_upper_adder.clear_node(); + } + break; + } + if(i == scc_graph.node_size() - 1 && temp_graph_upper_adder.node_size()>0) + { + temp_graph_upper_adder_list.push_back(temp_graph_upper_adder); + temp_graph_upper_adder.clear_node(); + } + } + std::cout<<"loop ended:temp graph upper adder size: "< 1) + { + for(int i = 1; i< int(temp_graph_upper_adder_list.size()); i++) + { + if(scc_index < subgraph_size) + { + Subgraphs.push_back(temp_graph_upper_adder_list[i]); + } + else + { + otherSubgraphs.push_back(temp_graph_upper_adder_list[i]); + } + } + } + std::cout<<"scc index"<= cut_rank[i]&& scc_node_rank[j].rank < cut_rank[i+1]) + { + *temp_graph_lower.add_node() = scc_graph.node(j); + } + } + if(scc_index < subgraph_size) + { + if(temp_graph_lower.node_size()>0) + { + Subgraphs.push_back(temp_graph_lower); + } + + } + else + { + if(temp_graph_lower.node_size()>0) + { + otherSubgraphs.push_back(temp_graph_lower); + } + } + } + onnx::GraphProto temp_graph_lower; + for(int j=0; j= cut_rank[cut_rank.size() -1]) + { + *temp_graph_lower.add_node() = scc_graph.node(j); + } + } + if(scc_index < subgraph_size) + { + if(temp_graph_lower.node_size()>0) + { + Subgraphs.push_back(temp_graph_lower); + } + + } + else + { + if(temp_graph_lower.node_size()>0) + { + otherSubgraphs.push_back(temp_graph_lower); + } + } + } + for(int i=Subgraphs.size() - 1; i>=0; i--) + { + if(Subgraphs[i].node_size() == 0) + { + Subgraphs.erase(Subgraphs.begin()+i); + } + } + for(int i=otherSubgraphs.size() - 1; i>=0; i--) + { + if(otherSubgraphs[i].node_size() == 0) + { + otherSubgraphs.erase(otherSubgraphs.begin()+i); + } + } +} +/** +* @brief Eliminate strongly connected components in the graph and partition them into individual subgraphs. +* +* @param [in] strongly_connected_subgraphs List of indices representing strongly connected components. +* @param [in,out] Subgraphs List of subgraphs that will be updated. +* @param [in,out] otherSubgraphs List of other subgraphs that will be updated. +* @param [in] g The original graph from which strongly connected components are derived. +* @pre The input graph `g` should be properly initialized and contain nodes. +* @post The `Subgraphs` and `otherSubgraphs` lists will be updated with individual nodes from each strongly connected component. +* @exception None +* @return None +*/ +void eliminate_scc_v3( + std::vector>& strongly_connected_subgraphs, + std::vector& Subgraphs, + std::vector& otherSubgraphs, + const onnx::GraphProto& g +) +{ + int subgraph_size = Subgraphs.size(); + for(int i = 0; i < int(strongly_connected_subgraphs.size()); i++) + { + for(const auto scc_index : strongly_connected_subgraphs[i]) + { + std::cout<<"scc index: "<=0; i--) + { + if(Subgraphs[i].node_size() == 0) + { + Subgraphs.erase(Subgraphs.begin()+i); + } + } + for(int i=otherSubgraphs.size() - 1; i>=0; i--) + { + if(otherSubgraphs[i].node_size() == 0) + { + otherSubgraphs.erase(otherSubgraphs.begin()+i); + } + } +} +/** +* @brief Determine the graph type based on the given index and return the corresponding graph. +* +* @param [in] index The index of the graph to determine. +* @param [in] Subgraphs List of subgraphs. +* @param [in] otherSubgraphs List of other subgraphs. +* @param [in] subgraph_size The size of the Subgraphs list. +* @pre The `index` should be a valid index within the combined range of `Subgraphs` and `otherSubgraphs`. +* @post None +* @exception None +* @return The graph corresponding to the given index. +*/ +onnx::GraphProto determinegraphtype_v2( + int index, + std::vector& Subgraphs, + std::vector& otherSubgraphs, + int subgraph_size +) +{ + if(index < subgraph_size) + { + return Subgraphs[index]; + } + else + { + return otherSubgraphs[index - subgraph_size]; + } +} +/** +* @brief Find pairs of strongly connected subgraphs based on input and output tensors. +* +* @param [in] strongly_connected_subgraphs List of strongly connected subgraphs. +* @param [in] Subgraphs List of subgraphs. +* @param [in] otherSubgraphs List of other subgraphs. +* @param [in] graphs_inputs List of input tensors for each graph. +* @param [in] graphs_outputs List of output tensors for each graph. +* @param [out] sccs_pairs List of pairs of strongly connected subgraphs. +* @pre The input lists should be properly initialized and contain valid data. +* @post The `sccs_pairs` list will contain pairs of indices representing connected subgraphs. +* @exception None +* @return None +*/ +void find_subgraph_pair_v2( + std::vector>& strongly_connected_subgraphs, + std::vector& Subgraphs, + std::vector& otherSubgraphs, + std::vector>& graphs_inputs, + std::vector>& graphs_outputs, + std::vector>>& sccs_pairs +) +{ + int count = 0; + for(const auto& strongly_connected :strongly_connected_subgraphs) + { + std::vector scc_graphs; + std::vector> scc_graphs_inputs; + std::vector> scc_graphs_outputs; + for(const auto& index : strongly_connected) + { + std::unordered_set graph_inputs = graphs_inputs[index]; + std::unordered_set graph_outputs = graphs_outputs[index]; + scc_graphs_inputs.push_back(graph_inputs); + scc_graphs_outputs.push_back(graph_outputs); + } + std::vector> scc_pairs; + std::vector is_pushed; + for(int j = 0; j < int(strongly_connected.size()); j++) + { + is_pushed.push_back(0); + } + for(int i = 0; i < int(strongly_connected.size()); i++) + { + for(const auto& graph_input : scc_graphs_inputs[i]) + { + for(int j = i + 1; j < int(strongly_connected.size()); j++) + { + std::vector scc_pair; + if(scc_graphs_outputs[j].find(graph_input)!=scc_graphs_outputs[j].end()&& is_pushed[j]==0) + { + for(const auto& graph_output : scc_graphs_outputs[i]) + { + if(scc_graphs_inputs[j].find(graph_output)!=scc_graphs_inputs[j].end()) + { + scc_pair.push_back(strongly_connected[i]); + scc_pair.push_back(strongly_connected[j]); + scc_pairs.push_back(scc_pair); + is_pushed[j]=1; + is_pushed[i]=1; + break; + } + } + } + if(is_pushed[i]==1) + { + break; + } + } + if(is_pushed[i]==1) + { + break; + } + } + } + if(scc_pairs.size() != 0) + { + sccs_pairs.push_back(scc_pairs); + } + count ++; + } + for(const auto& scc_pairs : sccs_pairs) + { + std::cout << "scc pair:"; + for(const auto& scc_pair : scc_pairs) + { + + for(const auto& scc_id : scc_pair) + { + std::cout << scc_id << " "; + } + std::cout << ";"; + } + std::cout << std::endl; + } +} +/** +* @brief Cut a pair of subgraphs into upper and lower parts based on node rank. +* +* @param [in] Subgraphs List of subgraphs. +* @param [in] otherSubgraphs List of other subgraphs. +* @param [in] graphs_inputs List of input tensors for each graph. +* @param [in] graphs_outputs List of output tensors for each graph. +* @param [in] scc_pair Pair of subgraph indices to be cut. +* @param [out] scc_pair_cut List of cut subgraphs (upper and lower parts of master graph and slave graph). +* @param [in] subgraph_size Size of subgraph. +* @pre The input lists should be properly initialized and contain valid data. +* @post The `scc_pair_cut` list will contain the cut subgraphs. +* @exception None +* @return A vector containing the index of the master graph and the cut rank. +*/ +std::vector cut_pair( + std::vector& Subgraphs, + std::vector& otherSubgraphs, + std::vector>& graphs_inputs, + std::vector>& graphs_outputs, + std::vector& scc_pair, + std::vector& scc_pair_cut, + int subgraph_size +) +{ + std::vector pair_node_list = calculate_node_rank(scc_pair, Subgraphs,otherSubgraphs); + int master_graph = 0; + for(const auto& node : pair_node_list) + { + if(node.rank==0) + { + int find_flag = -1; + onnx::GraphProto graph_temp = determinegraphtype_v2(scc_pair[0],Subgraphs, otherSubgraphs,subgraph_size); + for(const auto& graph_node : graph_temp.node()) + { + if(graph_node.name()==node.name) + { + find_flag = 1; + break; + } + } + if(find_flag == 1) + { + master_graph = 0; + break; + } + else{master_graph = 1;break;} + } + } + int slave_graph = 1 - master_graph; + //find the position where master and slave graph connect + int cut_rank = -1; + for(const auto& output : graphs_outputs[scc_pair[slave_graph]]) + { + for(const auto& input : graphs_inputs[scc_pair[master_graph]]) + { + + if(input.name ==output.name) + { + int node_index = 0; + onnx::GraphProto graph_temp = determinegraphtype_v2(scc_pair[slave_graph],Subgraphs, otherSubgraphs,subgraph_size); + for(const auto& graph_node : graph_temp.node()) + { + int update_node_rank = 0; + for(const auto& output_node : graph_node.output()) + { + if(output_node==output.name) + { + if(slave_graph==0) + { + if(cut_rank==-1||cut_rank>pair_node_list[node_index].rank) + { + cut_rank = pair_node_list[node_index].rank; + } + } + else + { + onnx::GraphProto graph_temp_1 = determinegraphtype_v2(scc_pair[master_graph], Subgraphs, otherSubgraphs,subgraph_size); + if(cut_rank==-1||cut_rank>pair_node_list[node_index+ graph_temp_1.node_size()].rank) + { + cut_rank = pair_node_list[node_index+ graph_temp_1.node_size()].rank; + } + } + update_node_rank = 1; + break; + } + } + if(update_node_rank == 1) + { + break; + } + node_index++; + } + break; + } + } + } + //cut master graph according to the rank + onnx::GraphProto master_upper; + onnx::GraphProto master_lower; + int node_index = 0; + onnx::GraphProto graph_temp = determinegraphtype_v2(scc_pair[master_graph],Subgraphs, otherSubgraphs,subgraph_size); + for(const auto& node : graph_temp.node()) + { + int node_rank; + if(master_graph == 0) + { + node_rank = pair_node_list[node_index].rank; + } + else + { + onnx::GraphProto graph_temp_2 = determinegraphtype_v2(scc_pair[slave_graph],Subgraphs, otherSubgraphs,subgraph_size); + node_rank = pair_node_list[node_index+ graph_temp_2.node_size()].rank; + } + if(node_rank return_value; + return_value.push_back(master_graph); + return_value.push_back(cut_rank); + return return_value; +} +/** +* @brief Eliminate pairs of subgraphs by cutting them and updating the subgraph lists. +* +* @param [in,out] Subgraphs List of subgraphs to be processed and updated. +* @param [in,out] otherSubgraphs List of other subgraphs to be processed and updated. +* @param [in] graphs_inputs List of input tensors for each graph. +* @param [in] graphs_outputs List of output tensors for each graph. +* @param [in] strongly_connected_subgraphs List of strongly connected subgraphs. +* @param [in] subgraph_size Size of subgraph. +* @pre The input lists should be properly initialized and contain valid data. +* @post The `Subgraphs` and `otherSubgraphs` lists will be updated with cut subgraphs. +* @exception None +* @return None +*/ +void eliminate_pair_v2( + std::vector& Subgraphs, + std::vector& otherSubgraphs, + std::vector>& graphs_inputs, + std::vector>& graphs_outputs, + std::vector>& strongly_connected_subgraphs, + int subgraph_size +) +{ + int original_node_size = 0; + for(auto& subgraph : Subgraphs) + { + original_node_size += subgraph.node_size(); + } + for(auto& subgraph : otherSubgraphs) + { + original_node_size += subgraph.node_size(); + } + std::vector>> sccs_pairs; + find_subgraph_pair_v2(strongly_connected_subgraphs, Subgraphs, otherSubgraphs, graphs_inputs, graphs_outputs, sccs_pairs); + for(auto& scc_pairs : sccs_pairs) + { + for(auto& scc_pair : scc_pairs) + { + std::vector scc_pair_cut; + cut_pair(Subgraphs, otherSubgraphs, graphs_inputs, graphs_outputs, scc_pair, scc_pair_cut, subgraph_size); + if(scc_pair[0] < subgraph_size) + { + Subgraphs[scc_pair[0]] = scc_pair_cut[0]; + Subgraphs.push_back(scc_pair_cut[1]); + } + else + { + otherSubgraphs[scc_pair[0]-subgraph_size] = scc_pair_cut[0]; + otherSubgraphs.push_back(scc_pair_cut[1]); + } + + if(scc_pair[1] < subgraph_size) + { + Subgraphs[scc_pair[1]] = scc_pair_cut[2]; + } + else + { + otherSubgraphs[scc_pair[1]-subgraph_size] = scc_pair_cut[2]; + } + } + } + for(int i=Subgraphs.size() - 1; i>=0; i--) + { + if(Subgraphs[i].node_size() == 0) + { + Subgraphs.erase(Subgraphs.begin()+i); + } + } + for(int i=otherSubgraphs.size() - 1; i>=0; i--) + { + if(otherSubgraphs[i].node_size() == 0) + { + otherSubgraphs.erase(otherSubgraphs.begin()+i); + } + } +} +/** +* @brief Find the successor or predecessor subgraph with the least number of nodes. +* +* @param [in] index Index of the current subgraph. +* @param [in] successor List of successor indices. +* @param [in] predecessor List of predecessor indices. +* @param [in] Subgraphs List of subgraphs. +* @param [in] otherSubgraphs List of other subgraphs. +* @pre The input lists should be properly initialized and contain valid data. +* @post None +* @exception None +* @return Index of the successor or predecessor subgraph with the least number of nodes, or -1 if no such subgraph exists. +*/ +int find_min_size(int index, std::vector& successor, std::vector& predecessor,std::vector &Subgraphs, std::vector &otherSubgraphs)//find the successor or predecessor with the least nodes +{ + std::vector size_list; + int min_index = -1; + int min_size = 10000; + for(int i = 0; i < int(successor.size()); i++) + { + std::cout<< "successor: "<= int(Subgraphs.size())&& index >= int(Subgraphs.size())) ) + { + if(successor[i] < int(Subgraphs.size())) + { + tempgraph = Subgraphs[successor[i]]; + } + else + { + tempgraph = otherSubgraphs[successor[i]-int(Subgraphs.size())]; + } + } + else + { + continue; + } + int size = int(tempgraph.node_size()); + std::cout<< " size:"<= int(Subgraphs.size())&& index >= int(Subgraphs.size()))) + { + if(predecessor[i] < int(Subgraphs.size())) + { + tempgraph = Subgraphs[predecessor[i]]; + } + else + { + tempgraph = otherSubgraphs[predecessor[i] - int(Subgraphs.size())]; + } + } + else + { + continue; + } + int size = int(tempgraph.node_size()); + std::cout<< " size:"< &node_io_size) { + std::unordered_set IOvalueNames = getIOvalue(g); + int* visited = (int*)malloc(g.node_size()*sizeof(int)); + std::vector adjacency_list=get_adjancency_list(g, visited); + std::vector otherSubgraphs; + determine_subgraphs_v2(g,otherSubgraphs, d, visited, adjacency_list,strategy); + std::cout<<"Partition Done"<().swap(adjacency_list); + int node_sum = 0; + // traverse the structures and print each element + std::ofstream outFile("./subgraphs_1.txt"); + if (!outFile.is_open()) { + std::cerr << "Error opening file." << std::endl; + exit(0); + } + int id = 0; + for (const auto& vec : Subgraphs) { + outFile << " subgraph" << id << ":"; + for (const auto& node : vec.node()) { + outFile << node.name() << " "; + } + id++; + outFile << std::endl; + node_sum += vec.node_size(); + } + int id_record = id; + std::ofstream outFile_2("./subgraphs_2.txt"); + if (!outFile_2.is_open()) { + std::cerr << "Error opening file." << std::endl; + exit(0); + } + std::cout << "before:" << std::endl; + for (const auto& vec : otherSubgraphs) { + outFile_2 << " subgraph" << id << ":"; + for (const auto& node : vec.node()) { + outFile_2 << node.name() << " "; + } + id++; + outFile_2 << std::endl; + node_sum += vec.node_size(); + } + std::vector> subgraphs_2_input_nodes_; + std::vector> subgraphs_2_nodes_; + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + std::unordered_set graphInputsNodes; + for (const auto& input : graphInputs) { + auto nodename = findInputNode(g, input.name); + if (nodename != "") { + graphInputsNodes.insert(nodename); + } + } + subgraphs_2_input_nodes_.push_back(graphInputsNodes); + subgraphs_2_nodes_.push_back(collectNodeNames(sg)); + } + int* is_merged = (int *)malloc(otherSubgraphs.size() * sizeof(int)); + for(int i=0;i> subgraphs_1_inputs; + std::vector> subgraphs_1_input_nodes; + std::vector> subgraphs_1_nodes; + for (const auto& sg : Subgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_1_inputs.push_back(graphInputs); + std::unordered_set graphInputsNodes; + for (const auto& input : graphInputs) { + auto nodename = findInputNode(g, input.name); + if (nodename != "") { + graphInputsNodes.insert(nodename); + } + } + subgraphs_1_input_nodes.push_back(graphInputsNodes); + subgraphs_1_nodes.push_back(collectNodeNames(sg)); + } + + std::vector> subgraphs_2_inputs; + std::vector> subgraphs_2_input_nodes; + std::vector> subgraphs_2_nodes; + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_2_inputs.push_back(graphInputs); + std::unordered_set graphInputsNodes; + for (const auto& input : graphInputs) { + auto nodename = findInputNode(g, input.name); + if (nodename != "") { + graphInputsNodes.insert(nodename); + } + } + subgraphs_2_input_nodes.push_back(graphInputsNodes); + subgraphs_2_nodes.push_back(collectNodeNames(sg)); + } + std::vector> subgraphs_1_outputs; + + int node_number=0; + + for (const auto& sg : Subgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_1_outputs.push_back(graphOutputs); + } + std::vector> subgraphs_2_outputs; + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_2_outputs.push_back(graphOutputs); + } + int graph_node_size_minus_constant = g.node_size(); + for(const auto& node: g.node()) + { + if(node.op_type() == "Constant") + { + graph_node_size_minus_constant--; + } + } + std::cout<<"total number of nodes in subgraphs:"<> graphs_inputs; + graphs_inputs.insert(graphs_inputs.end(),subgraphs_1_inputs.begin(),subgraphs_1_inputs.end()); + graphs_inputs.insert(graphs_inputs.end(),subgraphs_2_inputs.begin(),subgraphs_2_inputs.end()); + std::vector> graphs_outputs; + graphs_outputs.insert(graphs_outputs.end(),subgraphs_1_outputs.begin(),subgraphs_1_outputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_2_outputs.begin(),subgraphs_2_outputs.end()); + + std::vector> predecessors_Subgraphs(graphs_inputs.size()); + std::vector> successors_Subgraphs(graphs_inputs.size()); + for(int i=0; i predecessors; + for(const auto& g_input : graphs_inputs[i]) + { + for(int j=0; j< int(graphs_outputs.size());j++) + { + if((graphs_outputs[j].find(g_input)!=graphs_outputs[j].end())) + { + predecessors.push_back(j); + } + } + } + if(predecessors.size() == 0) + { + std::cout<<"subgraph "<> strongly_connected_subgraphs; + int* DFN = (int *)malloc(graphs_inputs.size() * sizeof(int)); + int* LOW = (int *)malloc(graphs_inputs.size() * sizeof(int)); + for(int i = 0; i < int(graphs_inputs.size()); i++) + { + DFN[i] = 0; + LOW[i] = 0; + } + for( int temp_count = 0 ; temp_count < int(predecessors_Subgraphs.size()); temp_count ++) + { + if(DFN[temp_count] == 0) + { + std::vector stack_subgraphs; + int depth = 0; + Tarjan(temp_count, depth, strongly_connected_subgraphs, DFN, + LOW, stack_subgraphs, successors_Subgraphs); + } + } + + std::string file_name_scc = "scc.txt"; + std::ofstream outfile_scc(file_name_scc); + outfile_scc << strongly_connected_subgraphs.size()< stack_subgraphs; + int depth = 0; + Tarjan(temp_count, depth, strongly_connected_subgraphs, DFN_, + LOW_, stack_subgraphs, successors_Subgraphs); + } + } + free(DFN_); + free(LOW_); + eliminate_scc_v2(strongly_connected_subgraphs, Subgraphs, otherSubgraphs, g); + ///////////////////// + strongly_connected_subgraphs.clear(); + predecessors_Subgraphs.clear(); + successors_Subgraphs.clear(); + std::vector>().swap(subgraphs_2_inputs); + std::vector>().swap(subgraphs_1_inputs); + std::vector>().swap(subgraphs_2_outputs); + std::vector>().swap(subgraphs_1_outputs); + std::vector>().swap(graphs_inputs); + std::vector>().swap(graphs_outputs); + for (const auto& sg : Subgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_1_inputs.push_back(graphInputs); + } + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_2_inputs.push_back(graphInputs); + } + node_number = 0; + for (const auto& sg : Subgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_1_outputs.push_back(graphOutputs); + } + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_2_outputs.push_back(graphOutputs); + } + graphs_inputs.insert(graphs_inputs.end(),subgraphs_1_inputs.begin(),subgraphs_1_inputs.end()); + graphs_inputs.insert(graphs_inputs.end(),subgraphs_2_inputs.begin(),subgraphs_2_inputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_1_outputs.begin(),subgraphs_1_outputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_2_outputs.begin(),subgraphs_2_outputs.end()); + for(int i=0; i predecessors; + for(const auto& g_input : graphs_inputs[i]) + { + for(int j=0; j< int(graphs_outputs.size());j++) + { + if((graphs_outputs[j].find(g_input)!=graphs_outputs[j].end())) + { + predecessors.push_back(j); + } + } + } + predecessors_Subgraphs.push_back(predecessors); + } + for(int i=0;i temp; + for(int j=0;j stack_subgraphs; + int depth = 0; + Tarjan(temp_count, depth, strongly_connected_subgraphs, DFN_2, + LOW_2, stack_subgraphs, successors_Subgraphs); + } + } + std::string file_name_scc2 = "scc2.txt"; + std::ofstream outfile_scc2(file_name_scc2); + for(const auto& scc : strongly_connected_subgraphs) + { + std::cout << "scc:"; + outfile_scc2 << "scc: "; + for(const auto& scc_id : scc) + { + outfile_scc2 << scc_id << " "; + } + outfile_scc2 << std::endl; + for(const auto& scc_id : scc) + { + std::cout << scc_id << " "; + outfile_scc2 << "subgraph" << scc_id << " input:"; + for(const auto& scc_input : graphs_inputs[scc_id]) + { + outfile_scc2 << scc_input.name << ";"; + } + outfile_scc2 << " output:"; + for(const auto& scc_output : graphs_outputs[scc_id]) + { + outfile_scc2 << scc_output.name << ";"; + } + outfile_scc2 << std::endl; + } + + std::cout << std::endl; + } + outfile_scc.close(); + free(DFN_2); + free(LOW_2); + //eliminate_scc_v2(strongly_connected_subgraphs, Subgraphs, otherSubgraphs, g); + int subgraph_size_2 = Subgraphs.size(); + int other_subgraph_size_2 = otherSubgraphs.size(); + std::vector eliminated_small_graph_id; + std::vector eliminated_small_graph_size; + std::vector eliminated_small_graph_size_2; + std::vector unmerged_graph_id; + for(int i =0; i< subgraph_size_2 + other_subgraph_size_2; i++) + { + std::cout << "i:" << i << std::endl; + if(i < subgraph_size_2) + { + if(Subgraphs[i].node_size() < 2) + { + int merge_id = find_min_size(i,successors_Subgraphs[i], predecessors_Subgraphs[i], Subgraphs, otherSubgraphs); + if(merge_id < subgraph_size_2&&merge_id >= 0) + { + mergeGraphs(Subgraphs[merge_id], Subgraphs[i]); + eliminated_small_graph_id.push_back(i); + eliminated_small_graph_size.push_back(Subgraphs[i].node_size()); + std::cout << "eliminating small graph "<= 0) + { + mergeGraphs(otherSubgraphs[merge_id - subgraph_size_2], Subgraphs[i]); + eliminated_small_graph_id.push_back(i); + eliminated_small_graph_size.push_back(Subgraphs[i].node_size()); + std::cout << "eliminating small graph "<= 0) + { + mergeGraphs(Subgraphs[merge_id], otherSubgraphs[i - subgraph_size_2]); + eliminated_small_graph_id.push_back(i); + eliminated_small_graph_size.push_back(otherSubgraphs[i - subgraph_size_2].node_size()); + std::cout << "eliminating small graph "<= 0) + { + mergeGraphs(otherSubgraphs[merge_id - subgraph_size_2], otherSubgraphs[i - subgraph_size_2]); + eliminated_small_graph_id.push_back(i); + eliminated_small_graph_size.push_back(otherSubgraphs[i - subgraph_size_2].node_size()); + std::cout << "eliminating small graph "<=0; i--) + { + if(std::find(unmerged_graph_id.begin(), unmerged_graph_id.end(), eliminated_small_graph_id[i]) != unmerged_graph_id.end()) + { + continue; + } + std::cout<1) + { + std::cout<<"eliminate Subgraphs"<1) + { + std::cout<<"eliminate otherSubgraphs"<>().swap(subgraphs_2_inputs); + std::vector>().swap(subgraphs_1_inputs); + std::vector>().swap(subgraphs_2_outputs); + std::vector>().swap(subgraphs_1_outputs); + std::vector>().swap(graphs_inputs); + std::vector>().swap(graphs_outputs); + for (const auto& sg : Subgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_1_inputs.push_back(graphInputs); + } + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_2_inputs.push_back(graphInputs); + } + node_number = 0; + for (const auto& sg : Subgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_1_outputs.push_back(graphOutputs); + } + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_2_outputs.push_back(graphOutputs); + } + graphs_inputs.insert(graphs_inputs.end(),subgraphs_1_inputs.begin(),subgraphs_1_inputs.end()); + graphs_inputs.insert(graphs_inputs.end(),subgraphs_2_inputs.begin(),subgraphs_2_inputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_1_outputs.begin(),subgraphs_1_outputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_2_outputs.begin(),subgraphs_2_outputs.end()); + for(int i=0; i predecessors; + for(const auto& g_input : graphs_inputs[i]) + { + for(int j=0; j< int(graphs_outputs.size());j++) + { + if((graphs_outputs[j].find(g_input)!=graphs_outputs[j].end())) + { + predecessors.push_back(j); + } + } + } + predecessors_Subgraphs.push_back(predecessors); + } + for(int i=0;i temp; + for(int j=0;j stack_subgraphs; + int depth = 0; + Tarjan(temp_count, depth, strongly_connected_subgraphs, DFN_3, + LOW_3, stack_subgraphs, successors_Subgraphs); + } + } + std::string file_name_scc3 = "scc3.txt"; + std::ofstream outfile_scc3(file_name_scc3); + for(const auto& scc : strongly_connected_subgraphs) + { + std::cout << "scc:"; + outfile_scc3 << "scc: "; + for(const auto& scc_id : scc) + { + outfile_scc3 << scc_id << " "; + } + outfile_scc3 << std::endl; + for(const auto& scc_id : scc) + { + std::cout << scc_id << " "; + outfile_scc3 << "subgraph" << scc_id << " input:"; + for(const auto& scc_input : graphs_inputs[scc_id]) + { + outfile_scc3 << scc_input.name << ";"; + } + outfile_scc3 << " output:"; + for(const auto& scc_output : graphs_outputs[scc_id]) + { + outfile_scc3 << scc_output.name << ";"; + } + outfile_scc3 << std::endl; + } + + std::cout << std::endl; + } + outfile_scc.close(); + free(DFN_3); + free(LOW_3); + std::cout << "node_num after cut " << node_num_all << std::endl; + if(node_num_all != g.node_size()) + { + std::cout << "num error!" << std::endl; + exit(0); + } + int count_cut_pair = 0; + while(1) + { + count_cut_pair ++; + if(count_cut_pair > 15) + { + std::cout << "cut pair error! So many times!" << std::endl; + exit(0); + break; + } + int subgraph_size = Subgraphs.size(); + std::vector> strongly_connected_subgraphs_all; + std::vector scc_all; + for(int i = 0; i < int(Subgraphs.size()) + int(otherSubgraphs.size()); i++) + { + scc_all.push_back(i); + } + strongly_connected_subgraphs_all.push_back(scc_all); + if(((count_cut_pair >1 &&count_cut_pair < 5)||(count_cut_pair >10 &&count_cut_pair < 13)) && strongly_connected_subgraphs.size() != 0) + { + std::cout <>().swap(subgraphs_2_inputs); + std::vector>().swap(subgraphs_1_inputs); + std::vector>().swap(subgraphs_2_outputs); + std::vector>().swap(subgraphs_1_outputs); + std::vector>().swap(graphs_inputs); + std::vector>().swap(graphs_outputs); + for (const auto& sg : Subgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_1_inputs.push_back(graphInputs); + } + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphInputs; + determineGraphInput(sg, IOvalueNames, graphInputs); + subgraphs_2_inputs.push_back(graphInputs); + } + node_number = 0; + for (const auto& sg : Subgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_1_outputs.push_back(graphOutputs); + } + for (const auto& sg : otherSubgraphs) { + std::unordered_set graphOutputs; + node_number+=sg.node_size(); + determineGraphOutput(g,sg, subgraphs_1_inputs, subgraphs_2_inputs, graphOutputs); + subgraphs_2_outputs.push_back(graphOutputs); + } + graphs_inputs.insert(graphs_inputs.end(),subgraphs_1_inputs.begin(),subgraphs_1_inputs.end()); + graphs_inputs.insert(graphs_inputs.end(),subgraphs_2_inputs.begin(),subgraphs_2_inputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_1_outputs.begin(),subgraphs_1_outputs.end()); + graphs_outputs.insert(graphs_outputs.end(),subgraphs_2_outputs.begin(),subgraphs_2_outputs.end()); + for(int i=0; i predecessors; + for(const auto& g_input : graphs_inputs[i]) + { + for(int j=0; j< int(graphs_outputs.size());j++) + { + if((graphs_outputs[j].find(g_input)!=graphs_outputs[j].end())) + { + predecessors.push_back(j); + } + } + } + predecessors_Subgraphs.push_back(predecessors); + } + for(int i=0;i temp; + for(int j=0;j stack_subgraphs; + int depth = 0; + Tarjan(temp_count, depth, strongly_connected_subgraphs, DFN_4, + LOW_4, stack_subgraphs, successors_Subgraphs); + } + } + std::string file_name_scc4 = "scc4.txt"; + std::ofstream outfile_scc4(file_name_scc4); + for(const auto& scc : strongly_connected_subgraphs) + { + std::cout << "scc4:"; + for(const auto& scc_id : scc) + { + std::cout << scc_id << " "; + outfile_scc4 << "subgraph" << scc_id << " input:"; + for(const auto& scc_input : graphs_inputs[scc_id]) + { + outfile_scc4 << scc_input.name << ";"; + } + outfile_scc4 << " output:"; + for(const auto& scc_output : graphs_outputs[scc_id]) + { + outfile_scc4 << scc_output.name << ";"; + } + outfile_scc4 << std::endl; + } + + std::cout << std::endl; + } + outfile_scc.close(); + free(DFN_4); + free(LOW_4); + std::cout << "node num in original graph: " << g.node_size() << std::endl; + std::cout << "node_num after cut " << node_num_all << std::endl; + if(node_num_all != g.node_size()) + { + std::cout << "num error!, time" < order_Subgraphs(graphs_inputs.size()); + std::vector issort_Subgraphs(graphs_inputs.size()); + while(!finished_flag) + { + finished_flag=1; + int changed_sort_flag=0; + if(sort_count==0) + { + changed_sort_flag=1; + for(int i=0; i=sub1_size?sub2_type:sub1_type)<<"subgraph"<<(i>=sub1_size?(i-sub1_size):i)<<": order"<=sub1_size?sub2_type:sub1_type)<<"subgraph"<<(i>=sub1_size?(i-sub1_size):i)<<": order"<=sub1_size?sub2_type:sub1_type)<<"subgraph"<<(i>=sub1_size?(i-sub1_size):i)<<": "; + for(auto element : predecessors_Subgraphs[i]) + { + std::cout << (element>=sub1_size?sub2_type:sub1_type)<<"subgraph"<<(element>=sub1_size?(element-sub1_size):element) <<"; "; + } + std::cout <=sub1_size?sub2_type:sub1_type)<<"subgraph"<<(i>=sub1_size?(i-sub1_size):i)<<": "; + for(auto element : successors_Subgraphs[i]) + { + std::cout << (element>=sub1_size?sub2_type:sub1_type)<<"subgraph"<<(element>=sub1_size?(element-sub1_size):element) <<"; "; + } + std::cout < +int DetermineStructure(const onnx::GraphProto& graph, Device &d,PartitionStrategy strategy) +{ + int node_index = 0; + std::vector> enabled_structure; + std::vector structure_temp; + while(node_index < graph.node_size()) + { + std::vector support_op; + const auto& node = graph.node(node_index); + switch (strategy) + { + case SPILTE_CPU_STRUCTURE_FIRST: + { + support_op =d.getCPUSupportOp(); + break; + } + case SPILTE_NPU_STRUCTURE_FIRST: + { + support_op =d.getNPUSupportOp(); + break; + } + default: + {break;} + } + if(std::find(support_op.begin(),support_op.end(),node.op_type())!=support_op.end()) + { + auto op_index=std::find(support_op.begin(),support_op.end(),node.op_type()); + structure_temp.push_back(*op_index); + } + else + { + if(structure_temp.size()>=3) + { + bool isequal=0; + for(const auto& structure : enabled_structure) + + { + if(std::equal(structure.begin(),structure.end(),structure_temp.begin(),structure_temp.end())) + { + isequal=1; + break; + } + } + if(isequal==0) + { + enabled_structure.push_back(structure_temp); + } + } + if(structure_temp.size()!=0){ + structure_temp.clear(); + } + + } + node_index++; + } + + for(const auto& structure : enabled_structure) + { + std::cout<<"{"; + for(const auto& op : structure) + { + std::cout <<"\""<< op << "\","; + } + std::cout<<"},"< +#include +#include "graph.h" +#include "partition.h" +#include "Python.h" + +int main(int argc, char* argv[]) { + std::string onnxFile; + if (argc > 1) { + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.substr(0, 7) == "--onnx=") { + onnxFile = arg.substr(7); + std::cout << "ONNX file: " << onnxFile << std::endl; + } + } + if (onnxFile.empty()) { + std::cout << "No ONNX file provided." << std::endl; + return -1; + } + } else { + printf("Please set valide args: ./onnx-subgraph --onnx=xxx.onnx\n"); + return -1; + } + + + Graph graph; + auto g = graph.GetGraphFromOnnx(onnxFile); + std::unordered_map node_io_size; + Partition p; + Device target; + target.updateOnnxFile(onnxFile); + target.GetDeviceJson("./scripts/config.json"); + p.PartitionGraph(g, target, PartitionStrategy::SPILTE_NPU_STRUCTURE_FIRST, node_io_size); + + Py_Initialize(); + if (!Py_IsInitialized()) { + std::cout << "python init fail" << std::endl; + return 0; + } + PyRun_SimpleString("import sys"); + PyRun_SimpleString("sys.path.append('.')"); + Py_Finalize(); + + return 0; +} diff --git a/tools/onnx-subgraph/test_model_download.sh b/tools/onnx-subgraph/test_model_download.sh new file mode 100644 index 00000000000..d6597d2dd79 --- /dev/null +++ b/tools/onnx-subgraph/test_model_download.sh @@ -0,0 +1,16 @@ +pip install onnx onnxsim + +if [ ! -d "./models/" ];then + mkdir ./models/ + else + echo "./models path existing" +fi + +cd ./models +wget https://media.githubusercontent.com/media/onnx/models/refs/heads/main/Computer_Vision/resnext26ts_Opset16_timm/resnext26ts_Opset16.onnx --no-check-certificate +#wget https://media.githubusercontent.com/media/onnx/models/refs/heads/main/Natural_Language_Processing/xmod_Opset16_transformers/xmod_Opset16.onnx --no-check-certificate + +onnxsim resnext26ts_Opset16.onnx ../resnet-test.onnx +#onnxsim xmod_Opset16.onnx ../xmod-transformer-test.onnx + +cd ..