From 81e6c21412acd75c01c375f0640fa5e69d03a777 Mon Sep 17 00:00:00 2001 From: Eric Lai Date: Tue, 6 Jul 2021 17:16:31 +0800 Subject: [PATCH] add tutorial LayerList --- .../basic_tutorials/tutorial_LayerList.py | 34 +++++++++++++++++++ tensorlayer/layers/core/core_tensorflow.py | 5 ++- 2 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 examples/basic_tutorials/tutorial_LayerList.py diff --git a/examples/basic_tutorials/tutorial_LayerList.py b/examples/basic_tutorials/tutorial_LayerList.py new file mode 100644 index 000000000..23d480fc7 --- /dev/null +++ b/examples/basic_tutorials/tutorial_LayerList.py @@ -0,0 +1,34 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +from tensorlayer.layers import Module, LayerList, Dense +import tensorlayer as tl + +d1 = Dense(n_units=800, act=tl.ReLU, in_channels=784, name='Dense1') +d2 = Dense(n_units=800, act=tl.ReLU, in_channels=800, name='Dense2') +d3 = Dense(n_units=10, act=tl.ReLU, in_channels=800, name='Dense3') + +layer_list = LayerList([d1, d2]) +# Inserts a given d2 before a given index in the list +layer_list.insert(1, d2) +layer_list.insert(2, d2) +# Appends d2 from a Python iterable to the end of the list. +layer_list.extend([d2]) +# Appends a given d3 to the end of the list. +layer_list.append(d3) + +print(layer_list) + +class model(Module): + def __init__(self): + super(model, self).__init__() + self._list = layer_list + def forward(self, inputs): + output = self._list[0](inputs) + for i in range(1, len(self._list)): + output = self._list[i](output) + return output + +net = model() +print(net) +print(net(tl.layers.Input((10, 784)))) \ No newline at end of file diff --git a/tensorlayer/layers/core/core_tensorflow.py b/tensorlayer/layers/core/core_tensorflow.py index f6a1dec5a..46bdc5fc1 100644 --- a/tensorlayer/layers/core/core_tensorflow.py +++ b/tensorlayer/layers/core/core_tensorflow.py @@ -678,10 +678,9 @@ class LayerList(Module): Examples: """ - def __init__(self, *args, **kwargs): + def __init__(self, args): super(LayerList, self).__init__() - if len(args) == 1: - self.extend(args[0]) + self.extend(args) def __getitem__(self, index): if isinstance(index, slice):