This repository was archived by the owner on Jun 27, 2023. It is now read-only.

support even bigger models #11

Merged 1 commit on Apr 26, 2019
1 change: 1 addition & 0 deletions .gitignore
@@ -108,3 +108,4 @@ venv.bak/

# bazel
bazel*
tf_trusted_custom_op/.bazelrc
18 changes: 9 additions & 9 deletions README.md
@@ -1,6 +1,6 @@
### tf-trusted
### TF Trusted

tf-trusted allows you to run most tensorflow models inside of an SGX device. It leverages a Tensorflow custom op to send gRPC messages into the SGX device via Asylo where the model is then run by Tensorflow Lite.
TF Trusted allows you to run most Tensorflow models inside of an Intel SGX device. It leverages a Tensorflow custom op to send gRPC messages into the Intel SGX device via Asylo where the model is then run by Tensorflow Lite.

First clone this repo and follow the instructions [here](tf_trusted_custom_op/README.md) to build the required custom operation.

@@ -12,9 +12,9 @@ We're pinned to version v0.3.4 of the docker container for now.
$ docker pull gcr.io/asylo-framework/asylo:buildenv-v0.3.4
```

##### Build and Run tf-trusted
##### Build and Run TF Trusted

Here we use docker to build tf-trusted and then run it.
Here we use docker to build TF Trusted and then run it.

```
$ docker run -it --rm \
@@ -34,7 +34,7 @@ Run the client.
In another shell run the following with the correct options for the model you're using:

```
cd ../tf_trusted_custom_op
cd tf_trusted_custom_op
python model_run.py --model_file <location of protobuf model> \
--input_file <location of input file, npy format> \
--input_name <input placeholder node name> \
@@ -45,7 +45,7 @@ The input and output names are needed by the Tensorflow Lite converter to conver

You should see some array output!

### Running on SGX Device.
### Running on Intel SGX Device.

If running on a machine with a SGX Device you run the following to install the needed dependencies.

@@ -67,7 +67,7 @@ The aesmd service manages the SGX device.
service aesmd start
```

##### Build and Run tf-trusted
##### Build and Run TF Trusted

Now we can run a similar command as before. We just need to point the docker container to the SGX device, the aesmd socket and tell bazel inside the asylo docker container to use the SGX device.

@@ -87,15 +87,15 @@ $ docker run -it --rm --device=/dev/isgx \
In another shell run the following with the correct options for the model you're using:

```
cd ../tf_trusted_custom_op
cd tf_trusted_custom_op
python model_run.py --model_file <location of protobuf model> \
--input_file <location of input file, npy format> \
--input_name <input placeholder node name> \
--output_name <output node name>
```


##### Install tf-trusted custom op
##### Install TF Trusted custom op

To be able to run the `model_run.py` script from anywhere on your machine you can install it with pip:

1 change: 1 addition & 0 deletions tf_trusted/tf_trusted_enclave.cc
@@ -72,6 +72,7 @@ asylo::Status GrpcServerEnclave::Initialize(
int selected_port;
builder.AddListeningPort(enclave_config.GetExtension(server_address),
::grpc::InsecureServerCredentials(), &selected_port);
builder.SetMaxReceiveMessageSize(INT_MAX);

// Add the translator service to the server.
builder.RegisterService(&modelService_);
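This one-line change lifts the enclave server's inbound gRPC limit: a gRPC server rejects messages larger than the default receive cap (about 4 MB), so a large serialized model sent in by the custom op would previously have been refused. Setting the cap to `INT_MAX` lets the enclave accept the whole model in a single message. Below is a minimal standalone sketch of the same server-side pattern, not code from this repo; the listening address and the `service` pointer are placeholders.

```
#include <climits>
#include <memory>

#include <grpcpp/grpcpp.h>

std::unique_ptr<grpc::Server> BuildServer(grpc::Service* service) {
  grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:50051", grpc::InsecureServerCredentials());
  // The default receive limit (~4 MB) is too small for a serialized model;
  // INT_MAX effectively removes the cap on inbound message size.
  builder.SetMaxReceiveMessageSize(INT_MAX);
  builder.RegisterService(service);
  return builder.BuildAndStart();
}
```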
6 changes: 5 additions & 1 deletion tf_trusted_custom_op/model_enclave_op.cc
@@ -33,8 +33,12 @@ class ModelLoadOp : public OpKernel {
const Tensor& model_name_tensor = context->input(0);
const Tensor& model_tensor = context->input(1);

grpc::ChannelArguments ch_args;
ch_args.SetMaxReceiveMessageSize(-1);

auto res = new ClientResource;
res->client = ModelClient(grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials()));
res->client = ModelClient(grpc::CreateCustomChannel("localhost:50051",
grpc::InsecureChannelCredentials(), ch_args));
auto resource_mgr = context->resource_manager();

Status s = resource_mgr->Create("connections", res_name, res);
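On the client side the custom op now creates its channel with `grpc::CreateCustomChannel` and explicit `grpc::ChannelArguments`; passing `-1` to `SetMaxReceiveMessageSize` removes the same default cap for responses coming back from the enclave (outbound messages are not capped by default in gRPC C++, so only the receive side needs adjusting). A minimal sketch of the pattern in isolation, with a placeholder target address:

```
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

std::shared_ptr<grpc::Channel> MakeUnlimitedChannel(const std::string& target) {
  grpc::ChannelArguments ch_args;
  // -1 means "no limit" for messages received on this channel.
  ch_args.SetMaxReceiveMessageSize(-1);
  return grpc::CreateCustomChannel(
      target, grpc::InsecureChannelCredentials(), ch_args);
}
```

A channel built this way can be handed to the repo's `ModelClient` wrapper just as in the diff above, e.g. `ModelClient(MakeUnlimitedChannel("localhost:50051"))`.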
12 changes: 11 additions & 1 deletion tf_trusted_custom_op/model_run.py
@@ -16,6 +16,7 @@
parser.add_argument('--benchmark', action='store_true', help='Run 100 timed inferences, results are stored in /tmp/tensorboard')
parser.add_argument('--batch_size', type=int, default='1', help='Batch size must match first dim of input file')
parser.add_argument('--model_name', type=str, default='model', help='Name your model!')
parser.add_argument('--input_shape', nargs='+', help='The input shape')
parser.add_argument('--from_file', action='store_true',
help='Tell the enclave to read from a file, file must exists on the enclave machine and already converted to tflite format')
config = parser.parse_args()
@@ -63,7 +64,15 @@ def get_input_shape(model_file, input_name):

# prepend import name gets added when calling import_graph_def
input = tf.get_default_session().graph.get_tensor_by_name('import/' + input_name + ":0")
shape = list(input.get_shape())

if config.input_shape is not None:
shape = config.input_shape
else:
try:
shape = list(input.get_shape())
except ValueError:
print("Error: Can't read shape from input try setting --input_shape instead")
exit()

# TODO i think this can just be inferred via the custom op
shape[0] = batch_size
@@ -91,6 +100,7 @@ def save_to_tensorboard(i, sess, run_metadata):
with open('{}/{}.ctr'.format("/tmp/tensorboard", session_tag), 'w') as f:
f.write(chrome_trace)


with tf.Session() as sess:
input_shape = get_input_shape(model_file, input_name)
output_shape, output_type = get_output_shape_and_type(model_file, output_name)
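The `model_run.py` change adds an escape hatch for frozen graphs whose input placeholder has no static shape: if `--input_shape` is supplied it is used directly, otherwise the script still tries `get_shape()` and now exits with a hint instead of raising a `ValueError`. A hedged usage example with placeholder file and node names and an assumed NHWC image shape:

```
python model_run.py --model_file model.pb \
                    --input_file input.npy \
                    --input_name input \
                    --output_name output \
                    --input_shape 1 224 224 3
```

Note that the flag is declared with `nargs='+'` and no `type=`, so the parsed values arrive as strings; downstream code expecting integer dimensions may need to cast them.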