diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 00000000..4ff0198b --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,29 @@ +# Release 0.1.10 + +## Major Features and Improvements +* Add json-example serving input functions to TF.Transform. +* Add variance analyzer to tf.transform. + +## Bug Fixes and Other Changes +* Remove duplication in output of `tft.tfidf`. +* Ensure ngrams output dense_shape is greater than or equal to 0. +* Alters the behavior and interface of tensorflow_transform.mappers.ngrams. +* Use `apache-beam[gcp] >=2,<3` +* Making TF Parallelism runner-dependent. +* Fixes issue with csv serving input function. + +## Deprecations +* `tft.map` will be removed on version 0.2.0, see the `examples` directory for + instructions on how to use `tft.apply_function` instead (as needed). +* `tft.tfidf_weights` will be removed on version 0.2.0, use `tft.tfidf` instead. + +# Release 0.1.9 + +## Major Features and Improvements +* Refactor internals to remove Column and Statistic classes + +## Bug Fixes and Other Changes +* Remove collections from graph to avoid warnings +* Return float32 from tfidf_weights +* Update tensorflow_transform to use tf.saved_model APIs. +* Add default values on example proto coder. diff --git a/examples/census_example.py b/examples/census_example.py index c2cc78cd..be58a3f5 100755 --- a/examples/census_example.py +++ b/examples/census_example.py @@ -116,7 +116,8 @@ def preprocessing_fn(inputs): def convert_label(label): table = lookup.string_to_index_table_from_tensor(['>50K', '<=50K']) return table.lookup(label) - outputs[LABEL_COLUMN] = tft.map(convert_label, inputs[LABEL_COLUMN]) + outputs[LABEL_COLUMN] = tft.apply_function(convert_label, + inputs[LABEL_COLUMN]) return outputs diff --git a/examples/sentiment_example.py b/examples/sentiment_example.py index d9e87249..b9a6b8a1 100644 --- a/examples/sentiment_example.py +++ b/examples/sentiment_example.py @@ -140,13 +140,13 @@ def preprocessing_fn(inputs): """Preprocess input columns into transformed columns.""" review = inputs[REVIEW_COLUMN] - review_tokens = tft.map(lambda x: tf.string_split(x, DELIMITERS), - review) + review_tokens = tf.string_split(review, DELIMITERS) review_indices = tft.string_to_int(review_tokens, top_k=VOCAB_SIZE) # Add one for the oov bucket created by string_to_int. - review_weight = tft.tfidf_weights(review_indices, VOCAB_SIZE + 1) + review_bow_indices, review_weight = tft.tfidf(review_indices, + VOCAB_SIZE + 1) return { - REVIEW_COLUMN: review_indices, + REVIEW_COLUMN: review_bow_indices, REVIEW_WEIGHT: review_weight, LABEL_COLUMN: inputs[LABEL_COLUMN] } diff --git a/examples/simple_example.py b/examples/simple_example.py index 646ce5e2..ab810324 100644 --- a/examples/simple_example.py +++ b/examples/simple_example.py @@ -32,11 +32,10 @@ def preprocessing_fn(inputs): x = inputs['x'] y = inputs['y'] s = inputs['s'] - x_centered = tft.map(lambda x, mean: x - mean, x, tft.mean(x)) + x_centered = x - tft.mean(x) y_normalized = tft.scale_to_0_1(y) s_integerized = tft.string_to_int(s) - x_centered_times_y_normalized = tft.map(lambda x, y: x * y, - x_centered, y_normalized) + x_centered_times_y_normalized = (x_centered * y_normalized) return { 'x_centered': x_centered, 'y_normalized': y_normalized, diff --git a/getting_started.md b/getting_started.md index 99e21828..c1480bc2 100644 --- a/getting_started.md +++ b/getting_started.md @@ -11,35 +11,27 @@ aspects of the usage of tf.Transform. 
## Defining a Preprocessing Function The most important concept of tf.Transform is the "preprocessing function". This -is a logical description of a transformation of a dataset. The dataset is -conceptualized as a dictionary of columns, and the preprocessing function is -defined by two basic mechanisms: - -1) Applying `tft.map`, which takes a user-defined function that accepts and -returns tensors. Such a function can use any TensorFlow operation to construct -the output tensors from the inputs. The remaining arguments of `tft.map` are the -columns that the function should be applied to. The number of columns provided -should equal the number of arguments to the user-defined function. Like the -Python `map` function, `tft.map` applies the user-provided function to the -elements in the columns specified. Each row is treated independently, and the -output is a column containing the results (but see the note on batching at the -end of this section). - -2) Applying any of the tf.Transform provided "analyzers". Analyzers are -functions that accept one or more `Column`s and return some summary statistic -for the input column or columns. A statistic is like a column except that it -only has a single value. An example of an analyzer is `tft.min` which computes -the minimum of a column. Currently tf.Transform provides a fixed set of -analyzers, but this will be extensible in future versions. - -In fact, `tft.map` can also accept statistics, which is how statistics are -incorporated into the user-defined pipeline. By combining analyzers and -`tft.map`, users can flexibly create pipelines for transforming their data. In -particular, users should define a "preprocessing function" which accepts and -returns columns. - -The following preprocessing function transforms each of three columns in -different ways, and combines two of the columns. +is a logical description of a transformation of a dataset. The preprocessing +function accepts and returns a dictionary of tensors (in this guide, "tensors" +generally means `Tensor`s or `SparseTensor`s). There are two kinds of functions +that can be used to define the preprocessing function: + +1) Any function that accepts and returns tensors. These will add TensorFlow +operations to the graph that transforms raw data into transformed data. + +2) Any of the tf.Transform provided "analyzers". Analyzers also accept and return +tensors, but unlike typical TensorFlow functions they don't add TF Operations +to the graph. Instead, they cause tf.Transform to compute a full pass operation +outside of TensorFlow, using the input tensor values over the full dataset to +generate a constant tensor that gets returned as the output. For example +`tft.min` computes the minimum of a tensor over the whole dataset. Currently +tf.Transform provides a fixed set of analyzers, but this will be extensible in +future versions. + +By combining analyzers and regular TensorFlow functions, users can flexibly +create pipelines for transforming their data. The following preprocessing +function transforms each of three features in different ways, and combines two +of the features. 
``` import tensorflow as tf @@ -49,11 +41,10 @@ def preprocessing_fn(inputs): x = inputs['x'] y = inputs['y'] s = inputs['s'] - x_centered = tft.map(lambda x, mean: x - mean, x, tft.mean(x)) + x_centered = x - tft.mean(x) y_normalized = tft.scale_to_0_1(y) s_integerized = tft.string_to_int(s) - x_centered_times_y_normalized = tft.map(lambda x, y: x * y, - x_centered, y_normalized) + x_centered_times_y_normalized = x_centered * y_normalized return { 'x_centered': x_centered, 'y_normalized': y_normalized, @@ -62,32 +53,29 @@ def preprocessing_fn(inputs): } ``` -`x`, `y` and `s` are local variables that represent input columns, that are -declared for code brevity. The first new column to be constructed, `x_centered`, -is constructed by composing `tft.map` and `tft.mean`. `tft.mean(x)` returns a -statistic representing the mean of the column `x`. The lambda passed to -`tft.map` is simply subtraction, where the first argument is the column `x` and -the second is the statistic `tft.mean(x)`. Thus `x_centered` is the column `x` +`x`, `y` and `s` are `Tensor`s that represent input features. The first new +tensor to be constructed, `x_centered`, is constructed by applying `tft.mean` +to `x` and subtracting this from `x`. `tft.mean(x)` returns a tensor +representing the mean of the tensor `x`. Thus `x_centered` is the tensor `x` with the mean subtracted. -The second new column is `y_normalized`, created in a similar manner but using +The second new tensor is `y_normalized`, created in a similar manner but using the convenience method `tft.scale_to_0_1`. This method does something similar under the hood to what is done to compute `x_centered`, namely computing a max and min and using these to scale `y`. -The column `s_integerized` shows an example of string manipulation. In this +The tensor `s_integerized` shows an example of string manipulation. In this simple case we take a string and map it to an integer. This too uses a -convenience function, where the analyzer that is applied computes the unique -values taken by the column, and the map uses these values as a dictionary to -convert to an integer. +convenience function, `tft.string_to_int`. This function uses an analyzer to +compute the unique values taken by the input strings, and then uses TensorFlow +ops to convert the input strings to indices in the table of unique values. -The final column shows that it is possible to use `tft.map` not only to -manipulate a single column but also to combine columns. +The final column shows that it is possible to use tensorflow operations to +create new features by combining tensors. -Note that `Column`s are not themselves wrappers around data. Rather they are -placeholders used to construct a definition of the user's logical pipeline. In -order to apply such a pipeline to data, we rely on a concrete implementation of -the tf.Transform API. The Apache Beam implementation provides `PTransform`s that +The preprocessing function defines a pipeline of operations on a dataset. In +order to apply such a pipeline, we rely on a concrete implementation of the +tf.Transform API. The Apache Beam implementation provides `PTransform`s that apply a user's preprocessing function to data. 
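For concreteness, here is a minimal sketch of applying a preprocessing function with the Beam implementation. The `PTransform` name `beam_impl.AnalyzeAndTransformDataset` and the metadata classes (`DatasetMetadata`, `ColumnSchema`, `FixedColumnRepresentation`) appear elsewhere in this codebase, but the import paths for the metadata module and the exact structure of the returned value are assumptions based on the tests, not a verbatim excerpt from this guide:

```
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam.impl as beam_impl
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema as sch


def preprocessing_fn(inputs):
  # As in the example above: scale 'y' to [0, 1] over the whole dataset.
  return {'y_normalized': tft.scale_to_0_1(inputs['y'])}


raw_data = [{'y': 1.0}, {'y': 2.0}, {'y': 3.0}]
raw_data_metadata = dataset_metadata.DatasetMetadata({
    'y': sch.ColumnSchema(tf.float32, [], sch.FixedColumnRepresentation()),
})

# AnalyzeAndTransformDataset first makes a full pass over the data to compute
# the analyzer outputs (here the min and max of 'y'), then runs the resulting
# TensorFlow graph over every instance. The return structure assumed here
# (transformed dataset plus a reusable transform function) follows the usage
# in the tests.
transformed_dataset, transform_fn = (
    (raw_data, raw_data_metadata)
    | beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset
```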
The typical workflow of a tf.Transform user will be to construct a preprocessing function, and then incorporate this into a larger Beam pipeline, ultimately materializing the data @@ -100,13 +88,14 @@ tf.Transform is to provide the TensorFlow graph for preprocessing that can be incorporated into the serving graph (and optionally the training graph), batching is also an important concept in tf.Transform. -While it is not obvious from the example above, the user defined function passed -to `tft.map` will be passed tensors representing *batches*, not individual -instances, just as will happen during training and serving with TensorFlow. This -is only the case for inputs that are `Column`s, not `Statistic`s. Thus the -actual tensors used in the `tft.map` for `x_centered` are 1) a rank 1 tensor, -representing a batch of values from the column `x`, whose first dimension is the -batch dimension; and 2) a rank 0 tensor representing the mean of that column. +While it is not obvious from the example above, the user defined preprocessing +function will be passed tensors representing *batches*, not individual +instances, just as will happen during training and serving with TensorFlow. On +the other hand, analyzers perform a computation over the whole dataset and +return a single value, not a batch of values. Thus `x` is a `Tensor` of shape +`(batch_size,)` while `tft.mean(x)` is a `Tensor` of shape `()`. The +subtraction `x - tft.mean(x)` involves broadcasting where the value of +`tft.mean(x)` is subtracted from every element of the batch represented by `x`. ## The Canonical Beam Implementation @@ -322,17 +311,20 @@ def preprocessing_fn(inputs): def convert_label(label): table = lookup.string_to_index_table_from_tensor(['>50K', '<=50K']) return table.lookup(label) - outputs[LABEL_COLUMN] = tft.map(convert_label, inputs[LABEL_COLUMN]) + outputs[LABEL_COLUMN] = tft.apply_function( + convert_label, inputs[LABEL_COLUMN]) return outputs ``` -One difference from the previous example is that we convert the outputs from -scalars to single element vectors. This allows the data to be correctly read -during training. Also for the label column, we manually specify the mapping from -string to index so that ">50K" gets mapped to 0 and "<=50K" gets mapped to 1. -This is useful so that we know which index in the trained model corresponds to -which label. +One difference from the previous example is that for the label column, we +manually specify the mapping from string to index so that ">50K" gets mapped to +0 and "<=50K" gets mapped to 1. This is useful so that we know which index in +the trained model corresponds to which label. We cannot apply the function +`convert_label` directly to its arguments because `tf.Transform` needs to know +about the `Table` defined in `convert_label`. That is, `convert_label` is not +a pure function but involves table initialization. For such functions, we use +`tft.apply_function` to wrap the function application. The `raw_data` variable represents a `PCollection` containing data in the same format as the list `raw_data` from the previous example, and the use of the diff --git a/setup.py b/setup.py index 08ab688e..5ff3f9c1 100644 --- a/setup.py +++ b/setup.py @@ -17,14 +17,12 @@ from setuptools import setup # Tensorflow transform version. -__version__ = '0.1.9' +__version__ = '0.1.10' def _make_required_install_packages(): return [ - # Using >= for better integration tests. During release this is - # automatically changed to a ==. 
- 'apache-beam[gcp] == 0.6.0', + 'apache-beam[gcp]>=2,<3', ] diff --git a/tensorflow_transform/analyzers.py b/tensorflow_transform/analyzers.py index 8108d357..876753a1 100644 --- a/tensorflow_transform/analyzers.py +++ b/tensorflow_transform/analyzers.py @@ -45,20 +45,20 @@ class Analyzer(object): Args: inputs: The inputs to the analyzer. - output_shapes_and_dtype: List of pairs of (shape, dtype) for each output. + output_shapes_and_dtype: List of pairs of (dtype, shape) for each output. spec: A description of the computation to be done. Raises: ValueError: If the inputs are not all `Tensor`s. """ - def __init__(self, inputs, output_shapes_and_dtypes, spec): + def __init__(self, inputs, output_dtypes_and_shapes, spec): for tensor in inputs: if not isinstance(tensor, tf.Tensor): raise ValueError('Analyzers can only accept `Tensor`s as inputs') self._inputs = inputs - self._outputs = [tf.placeholder(shape, dtype) - for shape, dtype in output_shapes_and_dtypes] + self._outputs = [tf.placeholder(dtype, shape) + for dtype, shape in output_dtypes_and_shapes] self._spec = spec tf.add_to_collection(ANALYZER_COLLECTION, self) @@ -131,7 +131,7 @@ def min(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin dimension and outputs a `Tensor` of the same shape as the input. Returns: - A `Tensor`. + A `Tensor`. Has the same type as `x`. """ return _numeric_combine(x, NumericCombineSpec.MIN, reduce_instance_dims) @@ -146,7 +146,7 @@ def max(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin dimension and outputs a vector of the same shape as the output. Returns: - A `Tensor`. + A `Tensor`. Has the same type as `x`. """ return _numeric_combine(x, NumericCombineSpec.MAX, reduce_instance_dims) @@ -161,7 +161,7 @@ def sum(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin dimension and outputs a vector of the same shape as the output. Returns: - A `Tensor`. + A `Tensor`. Has the same type as `x`. """ return _numeric_combine(x, NumericCombineSpec.SUM, reduce_instance_dims) @@ -176,7 +176,7 @@ def size(x, reduce_instance_dims=True): dimension and outputs a vector of the same shape as the output. Returns: - A `Tensor`. + A `Tensor`. Has the same type as `x`. """ with tf.name_scope('size'): # Note: Calling `sum` defined in this module, not the builtin. @@ -193,7 +193,10 @@ def mean(x, reduce_instance_dims=True): dimension and outputs a vector of the same shape as the output. Returns: - A `Tensor` containing the mean. + A `Tensor` containing the mean. If `x` is floating point, the mean will + have the same type as `x`. If `x` is integral, the output is cast to float32 + for int8 and int16 and float64 for int32 and int64 (similar to the behavior + of tf.truediv). """ with tf.name_scope('mean'): # Note: Calling `sum` defined in this module, not the builtin. @@ -201,6 +204,33 @@ def mean(x, reduce_instance_dims=True): sum(x, reduce_instance_dims), size(x, reduce_instance_dims)) +def var(x, reduce_instance_dims=True): + """Computes the variance of the values of a `Tensor` over the whole dataset. + + Uses the biased variance (0 delta degrees of freedom), as given by + (x - mean(x))**2 / length(x). + + Args: + x: A `Tensor`. + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the output. + + Returns: + A `Tensor` containing the variance. If `x` is floating point, the variance + will have the same type as `x`. 
If `x` is integral, the output is cast to + float32 for int8 and int16 and float64 for int32 and int64 (similar to the + behavior of tf.truediv). + """ + with tf.name_scope('var'): + # Note: Calling `mean`, `sum`, and `size` as defined in this module, not the + # builtins. + x_mean = mean(x, reduce_instance_dims) + # x_mean will be float32 or float64, depending on type of x. + squared_deviations = tf.square(tf.cast(x, x_mean.dtype) - x_mean) + return mean(squared_deviations, reduce_instance_dims) + + class UniquesSpec(object): """Operation to compute unique values.""" diff --git a/tensorflow_transform/beam/impl.py b/tensorflow_transform/beam/impl.py index 093ceed9..e2d3cf5e 100644 --- a/tensorflow_transform/beam/impl.py +++ b/tensorflow_transform/beam/impl.py @@ -83,6 +83,31 @@ def preprocessing_fn(inputs): _DEFAULT_DESIRED_BATCH_SIZE = 1000 +_DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER = { + # We rely on Beam to manage concurrency, i.e. we expect it to run one + # session per CPU--so we don't want to proliferate TF threads. + # Nonetheless we provide 4 threads per session for TF ops, 2 inter- + # and 2 intra-thread. In many cases only 2 of these will be runnable + # at any given time. This approach oversubscribes a bit to make sure + # the CPUs are really saturated. + # + beam.runners.DataflowRunner: + tf.ConfigProto( + use_per_session_threads=True, + inter_op_parallelism_threads=2, + intra_op_parallelism_threads=2).SerializeToString(), + +} + + +def _maybe_deserialize_tf_config(serialized_tf_config): + if serialized_tf_config is None: + return None + + result = tf.ConfigProto() + result.ParseFromString(serialized_tf_config) + return result + class Context(object): """Context manager for tensorflow-transform. @@ -162,28 +187,19 @@ class _RunMetaGraphDoFn(beam.DoFn): exclude_outputs: A list of names of outputs to exclude. desired_batch_size: The desired number of instances to convert into a batch before feeding to Tensorflow. + serialized_tf_config: A serialized tf.ConfigProto to use in sessions. None + implies use Tensorflow defaults. """ class _GraphState(object): - def __init__(self, saved_model_dir, input_schema, output_schema): + def __init__(self, saved_model_dir, input_schema, output_schema, + tf_config): self.saved_model_dir = saved_model_dir self.graph = tf.Graph() - self.session = tf.Session( - graph=self.graph, - # We rely on Beam to manage concurrency, i.e. we expect it to run one - # session per CPU--so we don't want to proliferate TF threads. - # Nonetheless we provide 4 threads per session for TF ops, 2 inter- - # and 2 intra-thread. In many cases only 2 of these will be runnable - # at any given time. This approach oversubscribes a bit to make sure - # the CPUs are really saturated. 
- # - config=tf.ConfigProto( - use_per_session_threads=True, - inter_op_parallelism_threads=2, - intra_op_parallelism_threads=2)) + self.session = tf.Session(graph=self.graph, config=tf_config) with self.graph.as_default(): - with tf.Session(): + with tf.Session(config=tf_config): inputs, outputs = saved_transform_io.partially_apply_saved_transform( saved_model_dir, {}) self.session.run(tf.tables_initializer()) @@ -206,13 +222,16 @@ def __init__(self, saved_model_dir, input_schema, output_schema): def __init__(self, input_schema, output_schema, + serialized_tf_config, exclude_outputs=None, desired_batch_size=_DEFAULT_DESIRED_BATCH_SIZE): super(_RunMetaGraphDoFn, self).__init__() self._input_schema = input_schema self._output_schema = output_schema + self._serialized_tf_config = serialized_tf_config self._exclude_outputs = exclude_outputs self._desired_batch_size = desired_batch_size + self._batch = [] self._graph_state = None @@ -232,8 +251,13 @@ def _flush_batch(self): self._graph_state.inputs, self._input_schema, self._batch) del self._batch[:] - return self._graph_state.session.run( - self._graph_state.outputs, feed_dict=feed_dict) + try: + return self._graph_state.session.run( + self._graph_state.outputs, feed_dict=feed_dict) + except Exception as e: + tf.logging.error('%s while applying transform function for tensors %s' % + (e, self._graph_state.outputs)) + raise def process(self, element, saved_model_dir): """Runs the given graph to realize the output `Tensor` or `SparseTensor`s. @@ -254,8 +278,9 @@ def process(self, element, saved_model_dir): if (getattr(self._thread_local, 'graph_state', None) is None or self._thread_local.graph_state.saved_model_dir != saved_model_dir): start = datetime.datetime.now() + tf_config = _maybe_deserialize_tf_config(self._serialized_tf_config) self._thread_local.graph_state = self._GraphState( - saved_model_dir, self._input_schema, self._output_schema) + saved_model_dir, self._input_schema, self._output_schema, tf_config) self._graph_load_seconds_distribution.update( int((datetime.datetime.now() - start).total_seconds())) self._graph_state = self._thread_local.graph_state @@ -393,8 +418,10 @@ def expand(self, tensor_pcoll_mapping): # a temp dir. This makes the wrapper idempotent since any retry will # use a different temp dir. def replace_tensors_with_constant_values( - saved_model_dir, tensor_value_mapping): - with tf.Session() as session: + saved_model_dir, tensor_value_mapping, serialized_tf_config): + + tf_config = _maybe_deserialize_tf_config(serialized_tf_config) + with tf.Session(config=tf_config) as session: temp_dir = _make_unique_temp_dir(base_temp_dir) input_tensors, output_tensors = ( saved_transform_io.partially_apply_saved_transform( @@ -402,10 +429,14 @@ def replace_tensors_with_constant_values( saved_transform_io.write_saved_transform_from_session( session, input_tensors, output_tensors, temp_dir) return temp_dir + + serialized_tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get( + self.pipeline.runner) return (transform_fn | 'ReplaceTensorsWithConstantValues' >> beam.Map( replace_tensors_with_constant_values, - tensor_value_mapping=tensor_value_mapping)) + tensor_value_mapping=tensor_value_mapping, + serialized_tf_config=serialized_tf_config)) class _ComputeTensorPcollMappingUpdate(beam.PTransform): """Create a mapping from `Tensor`s to PCollections. @@ -434,10 +465,12 @@ def expand(self, input_values_and_tensor_pcoll_mapping): >> _ReplaceTensorsWithConstants(self._saved_model_dir)) # Run the transform_fn. 
+ serialized_tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get( + self.pipeline.runner) analyzer_input_values = ( input_values | 'ComputeAnalyzerInputs' >> beam.ParDo( - _RunMetaGraphDoFn(input_schema, - self._analyzer_inputs_schema), + _RunMetaGraphDoFn(input_schema, self._analyzer_inputs_schema, + serialized_tf_config), saved_model_dir=beam.pvalue.AsSingleton(transform_fn))) # For each analyzer output, look up its input values (by tensor name) @@ -598,12 +631,16 @@ def convert_and_unbatch(batch_dict): return impl_helper.to_instance_dicts( impl_helper.make_output_dict(output_metadata.schema, batch_dict)) + serialized_tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get( + self.pipeline.runner) output_instances = ( input_values | 'Transform' >> beam.ParDo( - _RunMetaGraphDoFn(input_metadata.schema, - output_metadata.schema, - exclude_outputs=self._exclude_outputs), + _RunMetaGraphDoFn( + input_metadata.schema, + output_metadata.schema, + serialized_tf_config, + exclude_outputs=self._exclude_outputs), saved_model_dir=beam.pvalue.AsSingleton(transform_fn)) | 'ConvertAndUnbatch' >> beam.FlatMap(convert_and_unbatch)) return (output_instances, output_metadata) diff --git a/tensorflow_transform/beam/impl_test.py b/tensorflow_transform/beam/impl_test.py index 817a14ec..1ee98874 100644 --- a/tensorflow_transform/beam/impl_test.py +++ b/tensorflow_transform/beam/impl_test.py @@ -77,6 +77,11 @@ def assertAnalyzeAndTransformResults( | beam_impl.AnalyzeAndTransformDataset(preprocessing_fn)) self.assertDataEqual(expected_data, transformed_data) + # Use extra assertEqual for schemas, since full metadata assertEqual error + # message is not conducive to debugging. + self.assertEqual( + expected_metadata.schema.column_schemas, + transformed_metadata.schema.column_schemas) self.assertEqual(expected_metadata, transformed_metadata) def testApplySavedModelSingleInput(self): @@ -583,7 +588,72 @@ def preprocessing_fn(inputs): self.assertTrue( 'output_min must be less than output_max' in context.exception) - def testNumericAnalyzersWithScalarInputs(self): + def testNumericAnalyzersWithScalarInputs_int64(self): + self.numericAnalyzersWithScalarInputs( + input_dtype=tf.int64, + output_dtypes={ + 'min': tf.int64, + 'max': tf.int64, + 'sum': tf.int64, + 'size': tf.int64, + 'mean': tf.float64, + 'var': tf.float64 + } + ) + + def testNumericAnalyzersWithScalarInputs_int32(self): + self.numericAnalyzersWithScalarInputs( + input_dtype=tf.int32, + output_dtypes={ + 'min': tf.int32, + 'max': tf.int32, + 'sum': tf.int32, + 'size': tf.int32, + 'mean': tf.float64, + 'var': tf.float64 + } + ) + + def testNumericAnalyzersWithScalarInputs_int16(self): + self.numericAnalyzersWithScalarInputs( + input_dtype=tf.int16, + output_dtypes={ + 'min': tf.int16, + 'max': tf.int16, + 'sum': tf.int16, + 'size': tf.int16, + 'mean': tf.float32, + 'var': tf.float32 + } + ) + + def testNumericAnalyzersWithScalarInputs_float64(self): + self.numericAnalyzersWithScalarInputs( + input_dtype=tf.float64, + output_dtypes={ + 'min': tf.float64, + 'max': tf.float64, + 'sum': tf.float64, + 'size': tf.float64, + 'mean': tf.float64, + 'var': tf.float64 + } + ) + + def testNumericAnalyzersWithScalarInputs_float32(self): + self.numericAnalyzersWithScalarInputs( + input_dtype=tf.float32, + output_dtypes={ + 'min': tf.float32, + 'max': tf.float32, + 'sum': tf.float32, + 'size': tf.float32, + 'mean': tf.float32, + 'var': tf.float32 + } + ) + + def numericAnalyzersWithScalarInputs(self, input_dtype, output_dtypes): def preprocessing_fn(inputs): def 
repeat(in_tensor, value): batch_size = tf.shape(in_tensor)[0] @@ -594,24 +664,31 @@ def repeat(in_tensor, value): 'max': repeat(inputs['a'], tft.max(inputs['a'])), 'sum': repeat(inputs['a'], tft.sum(inputs['a'])), 'size': repeat(inputs['a'], tft.size(inputs['a'])), - 'mean': repeat(inputs['a'], tft.mean(inputs['a'])) + 'mean': repeat(inputs['a'], tft.mean(inputs['a'])), + 'var': repeat(inputs['a'], tft.var(inputs['a'])) } input_data = [{'a': 4}, {'a': 1}] input_metadata = dataset_metadata.DatasetMetadata({ - 'a': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()) + 'a': sch.ColumnSchema(input_dtype, [], sch.FixedColumnRepresentation()) }) expected_data = [ - {'min': 1, 'max': 4, 'sum': 5, 'size': 2, 'mean': 2.5}, - {'min': 1, 'max': 4, 'sum': 5, 'size': 2, 'mean': 2.5} + {'min': 1, 'max': 4, 'sum': 5, 'size': 2, 'mean': 2.5, 'var': 2.25}, + {'min': 1, 'max': 4, 'sum': 5, 'size': 2, 'mean': 2.5, 'var': 2.25} ] expected_metadata = dataset_metadata.DatasetMetadata({ - 'min': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), - 'max': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), - 'sum': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), - 'size': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), - 'mean': sch.ColumnSchema(tf.float64, [], - sch.FixedColumnRepresentation()) + 'min': sch.ColumnSchema(output_dtypes['min'], [], + sch.FixedColumnRepresentation()), + 'max': sch.ColumnSchema(output_dtypes['max'], [], + sch.FixedColumnRepresentation()), + 'sum': sch.ColumnSchema(output_dtypes['sum'], [], + sch.FixedColumnRepresentation()), + 'size': sch.ColumnSchema(output_dtypes['size'], [], + sch.FixedColumnRepresentation()), + 'mean': sch.ColumnSchema(output_dtypes['mean'], [], + sch.FixedColumnRepresentation()), + 'var': sch.ColumnSchema(output_dtypes['var'], [], + sch.FixedColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( input_data, input_metadata, preprocessing_fn, expected_data, @@ -638,7 +715,10 @@ def repeat(in_tensor, value): tft.size(inputs['a'], reduce_instance_dims=False)), 'mean': repeat(inputs['a'], - tft.mean(inputs['a'], reduce_instance_dims=False)) + tft.mean(inputs['a'], reduce_instance_dims=False)), + 'var': + repeat(inputs['a'], + tft.var(inputs['a'], reduce_instance_dims=False)) } input_data = [ @@ -653,13 +733,15 @@ def repeat(in_tensor, value): 'max': [8, 9, 10, 11], 'sum': [9, 11, 13, 15], 'size': [2, 2, 2, 2], - 'mean': [4.5, 5.5, 6.5, 7.5] + 'mean': [4.5, 5.5, 6.5, 7.5], + 'var': [12.25, 12.25, 12.25, 12.25] }, { 'min': [1, 2, 3, 4], 'max': [8, 9, 10, 11], 'sum': [9, 11, 13, 15], 'size': [2, 2, 2, 2], - 'mean': [4.5, 5.5, 6.5, 7.5] + 'mean': [4.5, 5.5, 6.5, 7.5], + 'var': [12.25, 12.25, 12.25, 12.25] }] expected_metadata = dataset_metadata.DatasetMetadata({ 'min': sch.ColumnSchema( @@ -671,6 +753,8 @@ def repeat(in_tensor, value): 'size': sch.ColumnSchema( tf.int64, [4], sch.FixedColumnRepresentation()), 'mean': sch.ColumnSchema( + tf.float64, [4], sch.FixedColumnRepresentation()), + 'var': sch.ColumnSchema( tf.float64, [4], sch.FixedColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( @@ -694,7 +778,9 @@ def repeat(in_tensor, value): 'size': repeat(inputs['a'], tft.size(inputs['a'], reduce_instance_dims=False)), 'mean': repeat(inputs['a'], - tft.mean(inputs['a'], reduce_instance_dims=False)) + tft.mean(inputs['a'], reduce_instance_dims=False)), + 'var': repeat(inputs['a'], + tft.var(inputs['a'], reduce_instance_dims=False)) } input_data = [ @@ -708,13 +794,15 @@ def 
repeat(in_tensor, value): 'max': [[8, 9], [10, 11]], 'sum': [[9, 11], [13, 15]], 'size': [[2, 2], [2, 2]], - 'mean': [[4.5, 5.5], [6.5, 7.5]] + 'mean': [[4.5, 5.5], [6.5, 7.5]], + 'var': [[12.25, 12.25], [12.25, 12.25]] }, { 'min': [[1, 2], [3, 4]], 'max': [[8, 9], [10, 11]], 'sum': [[9, 11], [13, 15]], 'size': [[2, 2], [2, 2]], - 'mean': [[4.5, 5.5], [6.5, 7.5]] + 'mean': [[4.5, 5.5], [6.5, 7.5]], + 'var': [[12.25, 12.25], [12.25, 12.25]] }] expected_metadata = dataset_metadata.DatasetMetadata({ 'min': sch.ColumnSchema( @@ -726,6 +814,8 @@ def repeat(in_tensor, value): 'size': sch.ColumnSchema( tf.int64, [2, 2], sch.FixedColumnRepresentation()), 'mean': sch.ColumnSchema( + tf.float64, [2, 2], sch.FixedColumnRepresentation()), + 'var': sch.ColumnSchema( tf.float64, [2, 2], sch.FixedColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( @@ -743,7 +833,8 @@ def repeat(in_tensor, value): 'max': repeat(inputs['a'], tft.max(inputs['a'])), 'sum': repeat(inputs['a'], tft.sum(inputs['a'])), 'size': repeat(inputs['a'], tft.size(inputs['a'])), - 'mean': repeat(inputs['a'], tft.mean(inputs['a'])) + 'mean': repeat(inputs['a'], tft.mean(inputs['a'])), + 'var': repeat(inputs['a'], tft.var(inputs['a'])) } input_data = [ @@ -754,8 +845,8 @@ def repeat(in_tensor, value): 'a': sch.ColumnSchema(tf.int64, [2, 2], sch.FixedColumnRepresentation()) }) expected_data = [ - {'min': 1, 'max': 7, 'sum': 32, 'size': 8, 'mean': 4.0}, - {'min': 1, 'max': 7, 'sum': 32, 'size': 8, 'mean': 4.0} + {'min': 1, 'max': 7, 'sum': 32, 'size': 8, 'mean': 4.0, 'var': 3.5}, + {'min': 1, 'max': 7, 'sum': 32, 'size': 8, 'mean': 4.0, 'var': 3.5} ] expected_metadata = dataset_metadata.DatasetMetadata({ 'min': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), @@ -763,7 +854,9 @@ def repeat(in_tensor, value): 'sum': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), 'size': sch.ColumnSchema(tf.int64, [], sch.FixedColumnRepresentation()), 'mean': sch.ColumnSchema(tf.float64, [], - sch.FixedColumnRepresentation()) + sch.FixedColumnRepresentation()), + 'var': sch.ColumnSchema(tf.float64, [], + sch.FixedColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( input_data, input_metadata, preprocessing_fn, expected_data, @@ -806,11 +899,18 @@ def mean_fn(inputs): return {'mean': repeat(inputs['a'], tft.mean(inputs['a']))} _ = input_dataset | beam_impl.AnalyzeDataset(mean_fn) + with self.assertRaises(TypeError): + def var_fn(inputs): + return {'var': repeat(inputs['a'], tft.var(inputs['a']))} + _ = input_dataset | beam_impl.AnalyzeDataset(var_fn) + def testStringToTFIDF(self): def preprocessing_fn(inputs): inputs_as_ints = tft.string_to_int(tf.string_split(inputs['a'])) + out_index, out_values = tft.tfidf(inputs_as_ints, 6) return { - 'tf_idf': tft.tfidf_weights(inputs_as_ints, 6) + 'tf_idf': out_values, + 'index': out_index } input_data = [{'a': 'hello hello world'}, {'a': 'hello goodbye hello world'}, @@ -820,24 +920,93 @@ def preprocessing_fn(inputs): }) # IDFs - # hello = log(3/3) = 0 - # world = log(3/3) = 0 - # goodbye = log(3/2) = 0.4054651081 - # I = log(3/2) - # like = log(3/2) - # pie = log(3/2) - log_3_over_2 = 0.4054651081 + # hello = log(4/3) = 0.28768 + # world = log(4/3) + # goodbye = log(4/2) = 0.69314 + # I = log(4/2) + # like = log(4/2) + # pie = log(4/2) + log_4_over_2 = 0.69314718056 + log_4_over_3 = 0.28768207245 expected_transformed_data = [{ - 'tf_idf': [0, 0, 0] + 'tf_idf': [(2/3)*log_4_over_3, (1/3)*log_4_over_3], + 'index': [0, 2] }, { - 'tf_idf': [0, (1/4)*log_3_over_2, 
0, 0] + 'tf_idf': [(2/4)*log_4_over_3, (1/4)*log_4_over_3, (1/4)*log_4_over_2], + 'index': [0, 2, 4] }, { - 'tf_idf': [(1/5)*log_3_over_2, (1/5)*log_3_over_2, (1/5)*log_3_over_2, - (1/5)*log_3_over_2, (1/5)*log_3_over_2] + 'tf_idf': [(3/5)*log_4_over_2, (1/5)*log_4_over_2, (1/5)*log_4_over_2], + 'index': [1, 3, 5] }] expected_transformed_schema = dataset_metadata.DatasetMetadata({ 'tf_idf': sch.ColumnSchema(tf.float32, [None], - sch.ListColumnRepresentation()) + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) + }) + self.assertAnalyzeAndTransformResults( + input_data, input_schema, preprocessing_fn, expected_transformed_data, + expected_transformed_schema) + + def testTFIDFNoData(self): + def preprocessing_fn(inputs): + inputs_as_ints = tft.string_to_int(tf.string_split(inputs['a'])) + out_index, out_values = tft.tfidf(inputs_as_ints, 6) + return { + 'tf_idf': out_values, + 'index': out_index + } + input_data = [{'a': ''}] + input_schema = dataset_metadata.DatasetMetadata({ + 'a': sch.ColumnSchema(tf.string, [], sch.FixedColumnRepresentation()) + }) + expected_transformed_data = [{'tf_idf': [], 'index': []}] + expected_transformed_schema = dataset_metadata.DatasetMetadata({ + 'tf_idf': sch.ColumnSchema(tf.float32, [None], + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) + }) + self.assertAnalyzeAndTransformResults( + input_data, input_schema, preprocessing_fn, expected_transformed_data, + expected_transformed_schema) + + def testStringToTFIDFEmptyDoc(self): + def preprocessing_fn(inputs): + inputs_as_ints = tft.string_to_int(tf.string_split(inputs['a'])) + out_index, out_values = tft.tfidf(inputs_as_ints, 6) + return { + 'tf_idf': out_values, + 'index': out_index + } + input_data = [{'a': 'hello hello world'}, + {'a': ''}, + {'a': 'hello goodbye hello world'}, + {'a': 'I like pie pie pie'}] + input_schema = dataset_metadata.DatasetMetadata({ + 'a': sch.ColumnSchema(tf.string, [], sch.FixedColumnRepresentation()) + }) + + log_5_over_2 = 0.91629073187 + log_5_over_3 = 0.51082562376 + expected_transformed_data = [{ + 'tf_idf': [(2/3)*log_5_over_3, (1/3)*log_5_over_3], + 'index': [0, 2] + }, { + 'tf_idf': [], + 'index': [] + }, { + 'tf_idf': [(2/4)*log_5_over_3, (1/4)*log_5_over_3, (1/4)*log_5_over_2], + 'index': [0, 2, 4] + }, { + 'tf_idf': [(3/5)*log_5_over_2, (1/5)*log_5_over_2, (1/5)*log_5_over_2], + 'index': [1, 3, 5] + }] + expected_transformed_schema = dataset_metadata.DatasetMetadata({ + 'tf_idf': sch.ColumnSchema(tf.float32, [None], + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( input_data, input_schema, preprocessing_fn, expected_transformed_data, @@ -845,7 +1014,40 @@ def preprocessing_fn(inputs): def testIntToTFIDF(self): def preprocessing_fn(inputs): - return {'tf_idf': tft.tfidf_weights(inputs['a'], 13)} + out_index, out_values = tft.tfidf(inputs['a'], 13) + return {'tf_idf': out_values, 'index': out_index} + input_data = [{'a': [2, 2, 0]}, + {'a': [2, 6, 2, 0]}, + {'a': [8, 10, 12, 12, 12]}, + ] + input_schema = dataset_metadata.DatasetMetadata({ + 'a': sch.ColumnSchema(tf.int64, [], sch.ListColumnRepresentation())}) + log_4_over_2 = 0.69314718056 + log_4_over_3 = 0.28768207245 + expected_data = [{ + 'tf_idf': [(1/3)*log_4_over_3, (2/3)*log_4_over_3], + 'index': [0, 2] + }, { + 'tf_idf': [(1/4)*log_4_over_3, (2/4)*log_4_over_3, 
(1/4)*log_4_over_2], + 'index': [0, 2, 6] + }, { + 'tf_idf': [(1/5)*log_4_over_2, (1/5)*log_4_over_2, (3/5)*log_4_over_2], + 'index': [8, 10, 12] + }] + expected_schema = dataset_metadata.DatasetMetadata({ + 'tf_idf': sch.ColumnSchema(tf.float32, [None], + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) + }) + self.assertAnalyzeAndTransformResults( + input_data, input_schema, preprocessing_fn, expected_data, + expected_schema) + + def testIntToTFIDFWithoutSmoothing(self): + def preprocessing_fn(inputs): + out_index, out_values = tft.tfidf(inputs['a'], 13, smooth=False) + return {'tf_idf': out_values, 'index': out_index} input_data = [{'a': [2, 2, 0]}, {'a': [2, 6, 2, 0]}, {'a': [8, 10, 12, 12, 12]}, @@ -853,17 +1055,22 @@ def preprocessing_fn(inputs): input_schema = dataset_metadata.DatasetMetadata({ 'a': sch.ColumnSchema(tf.int64, [], sch.ListColumnRepresentation())}) log_3_over_2 = 0.4054651081 + log_3 = 1.0986122886 expected_data = [{ - 'tf_idf': [0, 0, 0] + 'tf_idf': [(1/3)*log_3_over_2, (2/3)*log_3_over_2], + 'index': [0, 2] }, { - 'tf_idf': [0, (1/4)*log_3_over_2, 0, 0] + 'tf_idf': [(1/4)*log_3_over_2, (2/4)*log_3_over_2, (1/4)*log_3], + 'index': [0, 2, 6] }, { - 'tf_idf': [(1/5)*log_3_over_2, (1/5)*log_3_over_2, (1/5)*log_3_over_2, - (1/5)*log_3_over_2, (1/5)*log_3_over_2] + 'tf_idf': [(1/5)*log_3, (1/5)*log_3, (3/5)*log_3], + 'index': [8, 10, 12] }] expected_schema = dataset_metadata.DatasetMetadata({ 'tf_idf': sch.ColumnSchema(tf.float32, [None], - sch.ListColumnRepresentation()) + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( input_data, input_schema, preprocessing_fn, expected_data, @@ -874,8 +1081,11 @@ def testTFIDFWithOOV(self): def preprocessing_fn(inputs): inputs_as_ints = tft.string_to_int(tf.string_split(inputs['a']), top_k=test_vocab_size) + out_index, out_values = tft.tfidf(inputs_as_ints, + test_vocab_size+1) return { - 'tf_idf': tft.tfidf_weights(inputs_as_ints, test_vocab_size+1) + 'tf_idf': out_values, + 'index': out_index } input_data = [{'a': 'hello hello world'}, {'a': 'hello goodbye hello world'}, @@ -889,18 +1099,23 @@ def preprocessing_fn(inputs): # pie = log(3/2) = 0.4054651081 # world = log(3/3) = 0 # OOV - goodbye, I, like = log(3/3) - log_3_over_2 = 0.4054651081 + log_4_over_2 = 0.69314718056 + log_4_over_3 = 0.28768207245 expected_transformed_data = [{ - 'tf_idf': [0, 0, 0] + 'tf_idf': [(2/3)*log_4_over_3, (1/3)*log_4_over_3], + 'index': [0, 2] }, { - 'tf_idf': [0, 0, 0, 0] + 'tf_idf': [(2/4)*log_4_over_3, (1/4)*log_4_over_3, (1/4)*log_4_over_3], + 'index': [0, 2, 3] }, { - 'tf_idf': [0, 0, (1/5)*log_3_over_2, - (1/5)*log_3_over_2, (1/5)*log_3_over_2] + 'tf_idf': [(3/5)*log_4_over_2, (2/5)*log_4_over_3], + 'index': [1, 3] }] expected_transformed_schema = dataset_metadata.DatasetMetadata({ 'tf_idf': sch.ColumnSchema(tf.float32, [None], - sch.ListColumnRepresentation()) + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( input_data, input_schema, preprocessing_fn, expected_transformed_data, @@ -908,8 +1123,10 @@ def preprocessing_fn(inputs): def testTFIDFWithNegatives(self): def preprocessing_fn(inputs): + out_index, out_values = tft.tfidf(inputs['a'], 14) return { - 'tf_idf': tft.tfidf_weights(inputs['a'], 14) + 'tf_idf': out_values, + 'index': out_index } input_data = [{'a': [2, 2, 
-4]}, {'a': [2, 6, 2, -1]}, @@ -918,19 +1135,24 @@ def preprocessing_fn(inputs): input_schema = dataset_metadata.DatasetMetadata({ 'a': sch.ColumnSchema(tf.int64, [], sch.ListColumnRepresentation())}) - log_3_over_2 = 0.4054651081 + log_4_over_2 = 0.69314718056 + log_4_over_3 = 0.28768207245 # NOTE: -4 mod 14 = 10 expected_transformed_data = [{ - 'tf_idf': [0, 0, 0] + 'tf_idf': [(2/3)*log_4_over_3, (1/3)*log_4_over_3], + 'index': [2, 10] }, { - 'tf_idf': [0, (1/4)*log_3_over_2, 0, (1/4)*log_3_over_2] + 'tf_idf': [(2/4)*log_4_over_3, (1/4)*log_4_over_2, (1/4)*log_4_over_2], + 'index': [2, 6, 13] }, { - 'tf_idf': [(1/5)*log_3_over_2, 0, (1/5)*log_3_over_2, - (1/5)*log_3_over_2, (1/5)*log_3_over_2] + 'tf_idf': [(1/5)*log_4_over_2, (1/5)*log_4_over_3, (3/5)*log_4_over_2], + 'index': [8, 10, 12] }] expected_transformed_schema = dataset_metadata.DatasetMetadata({ 'tf_idf': sch.ColumnSchema(tf.float32, [None], - sch.ListColumnRepresentation()) + sch.ListColumnRepresentation()), + 'index': sch.ColumnSchema(tf.int64, [None], + sch.ListColumnRepresentation()) }) self.assertAnalyzeAndTransformResults( input_data, input_schema, preprocessing_fn, expected_transformed_data, @@ -1144,7 +1366,7 @@ def preprocessing_fn(inputs): return { 'index1': tft.string_to_int( - tft.map(tf.string_split, inputs['a']), + tf.string_split(inputs['a']), default_value=-99, top_k=1, num_oov_buckets=3) diff --git a/tensorflow_transform/mappers.py b/tensorflow_transform/mappers.py index 53499460..ba900f81 100644 --- a/tensorflow_transform/mappers.py +++ b/tensorflow_transform/mappers.py @@ -23,6 +23,7 @@ from tensorflow_transform import api from tensorflow.contrib import lookup +from tensorflow.python.util.deprecation import deprecated def scale_by_min_max(x, output_min=0.0, output_max=1.0): @@ -60,24 +61,41 @@ def scale_to_0_1(x): return scale_by_min_max(x, 0, 1) -def tfidf_weights(x, vocab_size): - """Maps the terms in x to their (1/doc_length) * inverse document frequency. +def tfidf(x, vocab_size, smooth=True): + """Maps the terms in x to their term frequency * inverse document frequency. + + The inverse document frequency of a term is calculated as + log((corpus size + 1) / (document frequency of term + 1)) by default. + + Example usage: + example strings [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] + in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2]], + values=[1, 2, 0, 0, 0, 3, 3, 0]) + out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + values=[1, 2, 0, 3, 0]) + SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + values=[(1/5)*log(3/2), (1/5)*log(3/2), 0, + 0, (2/3)*log(3/2)]) + NOTE that the first doc's duplicate "pie" strings have been combined to + one output, as have the second doc's duplicate "yum" strings. Args: x: A `SparseTensor` representing int64 values (most likely that are the result of calling string_to_int on a tokenized string). vocab_size: An int - the count of vocab used to turn the string into int64s including any OOV buckets. + smooth: A bool indicating if the inverse document frequency should be + smoothed. If True, which is the default, then the idf is calculated as + log((corpus size + 1) / (document frequency of term + 1)). + Otherwise, the idf is + log((corpus size) / (document frequency of term)), which could + result in a divizion by zero error. 
Returns: - A `SparseTensor` where each int value is mapped to a double equal to - (1 if that term appears in that row, 0 otherwise / the number of terms in - that row) * the log of (the number of rows in `x` / (1 + the number of - rows in `x` where the term appears at least once)) - - NOTE: - This is intented to be used with the feature_column 'sum' combiner to arrive - at the true term frequncies. + Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words]. + The first has values vocab_index, which is taken from input `x`. + The second has values tfidf_weight. """ def _to_vocab_range(x): @@ -87,97 +105,156 @@ def _to_vocab_range(x): values=tf.mod(x.values, vocab_size), dense_shape=x.dense_shape) - def _to_doc_contains_term(x): - """Creates a SparseTensor with 1s at every doc/term pair index. - - Args: - x : a SparseTensor of int64 representing string indices in vocab. + cleaned_input = _to_vocab_range(x) - Returns: - a SparseTensor with 1s at indices , - for every term/doc pair. - """ - # Construct intermediary sparse tensor with indices - # [, , ] and tf.ones values. - split_indices = tf.to_int64( - tf.split(x.indices, axis=1, num_or_size_splits=2)) - expanded_values = tf.to_int64(tf.expand_dims(x.values, 1)) - next_index = tf.concat( - [split_indices[0], split_indices[1], expanded_values], axis=1) + term_frequencies = _to_term_frequency(cleaned_input, vocab_size) - next_values = tf.ones_like(x.values) - vocab_size_as_tensor = tf.constant([vocab_size], dtype=tf.int64) - next_shape = tf.concat( - [x.dense_shape, vocab_size_as_tensor], 0) + count_docs_with_term_column = _count_docs_with_term(term_frequencies) + # Expand dims to get around the min_tensor_rank checks + sizes = tf.expand_dims(tf.shape(cleaned_input)[0], 0) + # [batch, vocab] - tfidf + tfidfs = _to_tfidf(term_frequencies, + analyzers.sum(count_docs_with_term_column, + reduce_instance_dims=False), + analyzers.sum(sizes), + smooth) + return _split_tfidfs_to_outputs(tfidfs) - next_tensor = tf.SparseTensor( - indices=tf.to_int64(next_index), - values=next_values, - dense_shape=next_shape) - # Take the intermediar tensor and reduce over the term_index_in_doc - # dimension. This produces a tensor with indices [, ] - # and values [count_of_term_in_doc] and shape batch x vocab_size - term_count_per_doc = tf.sparse_reduce_sum_sparse(next_tensor, 1) +def _split_tfidfs_to_outputs(tfidfs): + """Splits [batch, vocab]-weight into [batch, bow]-vocab & [batch, bow]-tfidf. - one_if_doc_contains_term = tf.SparseTensor( - indices=term_count_per_doc.indices, - values=tf.to_double(tf.greater(term_count_per_doc.values, 0)), - dense_shape=term_count_per_doc.dense_shape) + Args: + tfidfs: the `SparseTensor` output of _to_tfidf + Returns: + Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words]. + The first has values vocab_index, which is taken from input `x`. + The second has values tfidf_weight. + """ + # Split tfidfs tensor into [batch, dummy] -> vocab & [batch, dummy] -> tfidf + # The "dummy" index counts from 0 to the number of unique tokens in the doc. + # So example doc ["I", "like", "pie", "pie", "pie"], with 3 unique tokens, + # will have "dummy" indices [0, 1, 2]. The particular dummy index that any + # token recieves is not important, only that the tfidf value and vocab index + # have the *same* dummy index, so that feature_column can apply the weight to + # the correct vocab item. 
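+  # As a purely illustrative (hypothetical) example, a [batch, vocab] input
+  # with
+  #   indices=[[0, 0], [0, 2], [1, 3]], values=[w_a, w_b, w_c]
+  # is split into
+  #   vocab-index output:  indices=[[0, 0], [0, 1], [1, 0]], values=[0, 2, 3]
+  #   tfidf-weight output: indices=[[0, 0], [0, 1], [1, 0]], values=[w_a, w_b, w_c]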
+ dummy_index = segment_indices(tfidfs.indices[:, 0]) + out_index = tf.concat( + [tf.expand_dims(tfidfs.indices[:, 0], 1), + tf.expand_dims(dummy_index, 1)], 1) + out_shape = [tfidfs.dense_shape[0], tf.reduce_max(dummy_index)+1] + + de_duped_indicies_out = tf.SparseTensor( + indices=out_index, + values=tfidfs.indices[:, 1], + dense_shape=out_shape) + de_duped_tfidf_out = tf.SparseTensor( + indices=out_index, + values=tfidfs.values, + dense_shape=out_shape) + return de_duped_indicies_out, de_duped_tfidf_out + + +def _to_term_frequency(x, vocab_size): + """Creates a SparseTensor of term frequency for every doc/term pair. - return one_if_doc_contains_term + Args: + x : a SparseTensor of int64 representing string indices in vocab. + vocab_size: An int - the count of vocab used to turn the string into int64s + including any OOV buckets. - def _to_tfidf(x, reduced_term_freq, corpus_size): - """Calculates the inverse document frequency of terms in the corpus. + Returns: + a SparseTensor with the count of times a term appears in a document at + indices , , + with size (num_docs_in_batch, vocab_size). + """ + # Construct intermediary sparse tensor with indices + # [, , ] and tf.ones values. + split_indices = tf.to_int64( + tf.split(x.indices, axis=1, num_or_size_splits=2)) + expanded_values = tf.to_int64(tf.expand_dims(x.values, 1)) + next_index = tf.concat( + [split_indices[0], split_indices[1], expanded_values], axis=1) + + next_values = tf.ones_like(x.values) + vocab_size_as_tensor = tf.constant([vocab_size], dtype=tf.int64) + next_shape = tf.concat( + [x.dense_shape, vocab_size_as_tensor], 0) + + next_tensor = tf.SparseTensor( + indices=tf.to_int64(next_index), + values=next_values, + dense_shape=next_shape) + + # Take the intermediar tensor and reduce over the term_index_in_doc + # dimension. This produces a tensor with indices [, ] + # and values [count_of_term_in_doc] and shape batch x vocab_size + term_count_per_doc = tf.sparse_reduce_sum_sparse(next_tensor, 1) + + dense_doc_sizes = tf.to_double(tf.sparse_reduce_sum(tf.SparseTensor( + indices=x.indices, + values=tf.ones_like(x.values), + dense_shape=x.dense_shape), 1)) + + gather_indices = term_count_per_doc.indices[:, 0] + gathered_doc_sizes = tf.gather(dense_doc_sizes, gather_indices) + + term_frequency = (tf.to_double(term_count_per_doc.values) / + tf.to_double(gathered_doc_sizes)) + return tf.SparseTensor( + indices=term_count_per_doc.indices, + values=term_frequency, + dense_shape=term_count_per_doc.dense_shape) - Args: - x : a `SparseTensor` of int64 representing string indices in vocab. - reduced_term_freq: A `Tensor` of shape (vocabSize,) that represents the - count of the number of documents with each term. - corpus_size: A scalar count of the number of documents in the corpus - Returns: - The tf*idf values - """ - # Add one to the reduced term freqnencies to avoid dividing by zero. - idf = tf.log(tf.to_double(corpus_size) / ( - 1.0 + tf.to_double(reduced_term_freq))) +def _to_tfidf(term_frequency, reduced_term_freq, corpus_size, smooth): + """Calculates the inverse document frequency of terms in the corpus. - dense_doc_sizes = tf.to_double(tf.sparse_reduce_sum(tf.SparseTensor( - indices=x.indices, - values=tf.ones_like(x.values), - dense_shape=x.dense_shape), 1)) + Args: + term_frequency: The `SparseTensor` output of _to_term_frequency. + reduced_term_freq: A `Tensor` of shape (vocabSize,) that represents the + count of the number of documents with each term. + corpus_size: A scalar count of the number of documents in the corpus. 
+ smooth: A bool indicating if the idf value should be smoothed. See + tfidf_weights documentation for details. - # For every term in x, divide the idf by the doc size. - # The two gathers both result in shape - idf_over_doc_size = (tf.gather(idf, x.values) / - tf.gather(dense_doc_sizes, x.indices[:, 0])) + Returns: + A `SparseTensor` with indices=, , + values=term frequency * inverse document frequency, + and shape=(batch, vocab_size) + """ + # The idf tensor has shape (vocab_size,) + if smooth: + idf = tf.log((tf.to_double(corpus_size) + 1.0) / ( + 1.0 + tf.to_double(reduced_term_freq))) + else: + idf = tf.log(tf.to_double(corpus_size) / ( + tf.to_double(reduced_term_freq))) - return tf.SparseTensor( - indices=x.indices, - values=tf.to_float(idf_over_doc_size), - dense_shape=x.dense_shape) + gathered_idfs = tf.gather(tf.squeeze(idf), term_frequency.indices[:, 1]) + tfidf_values = tf.to_float(term_frequency.values) * tf.to_float(gathered_idfs) - cleaned_input = _to_vocab_range(x) + return tf.SparseTensor( + indices=term_frequency.indices, + values=tfidf_values, + dense_shape=term_frequency.dense_shape) - docs_with_terms = _to_doc_contains_term(cleaned_input) - def count_docs_with_term(term_frequency): - # Sum w/in batch. - count_of_doc_inter = tf.SparseTensor( - indices=term_frequency.indices, - values=tf.ones_like(term_frequency.values), - dense_shape=term_frequency.dense_shape) - out = tf.sparse_reduce_sum(count_of_doc_inter, axis=0) - return tf.expand_dims(out, 0) +def _count_docs_with_term(term_frequency): + """Computes the number of documents in a batch that contain each term. - count_docs_with_term_column = count_docs_with_term(docs_with_terms) - # Expand dims to get around the min_tensor_rank checks - sizes = tf.expand_dims(tf.shape(cleaned_input)[0], 0) - return _to_tfidf(cleaned_input, - analyzers.sum(count_docs_with_term_column, - reduce_instance_dims=False), - analyzers.sum(sizes)) + Args: + term_frequency: The `SparseTensor` output of _to_term_frequency. + Returns: + A `Tensor` of shape (vocab_size,) that contains the number of documents in + the batch that contain each term. + """ + count_of_doc_inter = tf.SparseTensor( + indices=term_frequency.indices, + values=tf.ones_like(term_frequency.values), + dense_shape=term_frequency.dense_shape) + out = tf.sparse_reduce_sum(count_of_doc_inter, axis=0) + return tf.expand_dims(out, 0) def string_to_int(x, default_value=-1, top_k=None, frequency_threshold=None, @@ -259,69 +336,79 @@ def segment_indices(segment_ids): segment_starts) -def ngrams(strings, ngram_range): +def ngrams(tokens, ngram_range, separator): """Create a `SparseTensor` of n-grams. - Given a vector of strings, return a sparse matrix containing the ngrams from - each string. Each row in the output `SparseTensor` contains the set of ngrams - from the corresponding element in the input `Tensor`. + Given a `SparseTensor` of tokens, returns a `SparseTensor` containing the + ngrams that can be constructed from each row. - The output ngrams including all whitespace and punctuation from the original - strings. + `separator` is inserted between each pair of tokens, so " " would be an + appropriate choice if the tokens are words, while "" would be an appropriate + choice if they are characters. 
Example: - strings = ['ab: c', 'wxy.'] + `tokens` is a `SparseTensor` with + + indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]] + values = ['One', 'was', 'Johnny', 'Two', 'was', 'a', 'rat'] + dense_shape = [2, 4] + + If we set ngrams_range = (1,3) + separator = ' ' output is a `SparseTensor` with - indices = [[0, 0], [0, 1], ..., [0, 11], [1, 0], [1, 1], ..., [1, 8]] - values = ['a', 'ab', 'ab:', 'b', 'b:', 'b: ', ':', ': ', ': c', ' ', ' c', - 'c', 'w', 'wx', 'wxy', 'x', 'xy', 'xy.', 'y', 'y.', '.'] - dense_shape = [2, 12] + indices = [[0, 0], [0, 1], [0, 2], ..., [1, 6], [1, 7], [1, 8]] + values = ['One', 'One was', 'One was Johnny', 'was', 'was Johnny', 'Johnny', + 'Two', 'Two was', 'Two was a', 'was', 'was a', 'was a rat', 'a', + 'a rat', 'rat'] + dense_shape = [2, 9] Args: - strings: A `Tensor` of strings with shape (batch_size,). + tokens: a two-dimensional`SparseTensor` of dtype `tf.string` containing + tokens that will be used to construct ngrams. ngram_range: A pair with the range (inclusive) of ngram sizes to return. + separator: a string that will be inserted between tokens when ngrams are + constructed. Returns: - A `SparseTensor` containing all ngrams from each element of the input. + A `SparseTensor` containing all ngrams from each row of the input. Raises: ValueError: if ngram_range[0] < 1 or ngram_range[1] < ngram_range[0] """ - # This function is implemented as follows. First we split the input. If the - # input is ['abcd', 'q', 'xyz'] then the split operation returns a - # SparseTensor with + # This function is implemented as follows. Assume we start with the following + # `SparseTensor`: # # indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [2, 0], [2, 1], [2, 2]] # values=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] # dense_shape=[3, 4] # - # We then create shifts of the values and first column of indices, buffering - # to avoid overruning the end of the array, so the shifted values (if we are - # creating ngrams up to size 3) are + # First we then create shifts of the values and first column of indices, + # buffering to avoid overruning the end of the array, so the shifted values + # (if we are ngrams up to size 3) are # # shifted_batch_indices[0]=[0, 0, 0, 0, 1, 2, 2, 2] - # shifted_chars[0]=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] + # shifted_tokens[0]=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] # # shifted_batch_indices[1]=[0, 0, 0, 1, 2, 2, 2, -1] - # shifted_chars[1]=['b', 'c', 'd', 'q', 'x', 'y', 'z', ''] + # shifted_tokens[1]=['b', 'c', 'd', 'q', 'x', 'y', 'z', ''] # # shifted_batch_indices[2]=[0, 0, 1, 2, 2, 2, -1, -1] - # shifted_chars[2]=['c', 'd', 'q', 'x', 'y', 'z', '', ''] + # shifted_tokens[2]=['c', 'd', 'q', 'x', 'y', 'z', '', ''] # # These shifted ngrams are used to create the ngrams as follows. We use - # tf.string_join to join shifted_chars[:k] to create k-grams. The batch that - # the first of these belonged to is given by shifted_batch_indices[0]. - # However some of these will cross the boundaries between 'batches' and so - # we we create a boolean mask which is True when shifted_indices[:k] are all - # equal. + # tf.string_join to join shifted_tokens[:k] to create k-grams. The `separator` + # string is inserted between each pair of tokens in the k-gram. + # The batch that the first of these belonged to is given by + # shifted_batch_indices[0]. However some of these will cross the boundaries + # between 'batches' and so we we create a boolean mask which is True when + # shifted_indices[:k] are all equal. 
# # This results in tensors of ngrams, their batch indices and a boolean mask, # which we then use to construct the output SparseTensor. - chars = tf.string_split(strings, delimiter='') if ngram_range[0] < 1 or ngram_range[1] < ngram_range[0]: raise ValueError('Invalid ngram_range: %r' % (ngram_range,)) @@ -333,18 +420,18 @@ def _sliding_windows(values, num_shifts, fill_value): for i in range(num_shifts)] shifted_batch_indices = _sliding_windows( - chars.indices[:, 0], ngram_range[1] + 1, tf.constant(-1, dtype=tf.int64)) - shifted_chars = _sliding_windows(chars.values, ngram_range[1] + 1, '') + tokens.indices[:, 0], ngram_range[1] + 1, tf.constant(-1, dtype=tf.int64)) + shifted_tokens = _sliding_windows(tokens.values, ngram_range[1] + 1, '') # Construct a tensor of the form # [['a', 'ab, 'abc'], ['b', 'bcd', cde'], ...] def _string_join(tensors): if tensors: - return tf.string_join(tensors) + return tf.string_join(tensors, separator=separator) else: return - ngrams_array = [_string_join(shifted_chars[:k]) + ngrams_array = [_string_join(shifted_tokens[:k]) for k in range(ngram_range[0], ngram_range[1] + 1)] ngrams_tensor = tf.stack(ngrams_array, 1) @@ -358,7 +445,7 @@ def _string_join(tensors): # Construct a tensor with the batch that each ngram in ngram_tensor belongs # to. - batch_indices = tf.tile(tf.expand_dims(chars.indices[:, 0], 1), + batch_indices = tf.tile(tf.expand_dims(tokens.indices[:, 0], 1), [1, ngram_range[1] + 1 - ngram_range[0]]) # Apply the boolean mask and construct a SparseTensor with the given indices @@ -367,8 +454,174 @@ def _string_join(tensors): batch_indices = tf.boolean_mask(batch_indices, valid_ngram) ngrams_tensor = tf.boolean_mask(ngrams_tensor, valid_ngram) instance_indices = segment_indices(batch_indices) + dense_shape_second_dim = tf.maximum(tf.reduce_max(instance_indices), -1) + 1 return tf.SparseTensor( - tf.stack([batch_indices, instance_indices], 1), - ngrams_tensor, - tf.stack([tf.size(strings, out_type=tf.int64), - tf.reduce_max(instance_indices) + 1], 0)) + indices=tf.stack([batch_indices, instance_indices], 1), + values=ngrams_tensor, + dense_shape=tf.stack( + [tokens.dense_shape[0], dense_shape_second_dim])) + + +def hash_strings(strings, hash_buckets, key=None): + """Hash strings into buckets. + + Args: + strings: a `Tensor` or `SparseTensor` of dtype `tf.string`. + hash_buckets: the number of hash buckets. + key: optional. An array of two Python `uint64`. If passed, output will be + a deterministic function of `strings` and `key`. Note that hashing will be + slower if this value is specified. + + Returns: + A `Tensor` or `SparseTensor` of dtype `tf.int64` with the same shape as the + input `strings`. + + Raises: + TypeError: if `strings` is not a `Tensor` or `SparseTensor` of dtype + `tf.string`. + """ + if (not isinstance(strings, (tf.Tensor, + tf.SparseTensor))) or strings.dtype != tf.string: + raise TypeError( + 'Input to hash_strings must be a Tensor or SparseTensor of dtype ' + 'string; got {}'. 
+ format(strings.dtype)) + if isinstance(strings, tf.SparseTensor): + return tf.SparseTensor(indices=strings.indices, + values=hash_strings( + strings.values, hash_buckets, key), + dense_shape=strings.dense_shape) + if key is None: + return tf.string_to_hash_bucket_fast( + strings, hash_buckets, name='hash_strings') + return tf.string_to_hash_bucket_strong( + strings, hash_buckets, key, name='hash_strings') + + +############################################################################## +### ### +### DEPRECATED ### +### ### +############################################################################## + + +@deprecated('2017-08-25', + 'Use tfidf() instead.') +def tfidf_weights(x, vocab_size): + """Maps the terms in x to their (1/doc_length) * inverse document frequency. + + Args: + x: A `SparseTensor` representing int64 values (most likely that are the + result of calling string_to_int on a tokenized string). + vocab_size: An int - the count of vocab used to turn the string into int64s + including any OOV buckets. + + Returns: + A `SparseTensor` where each int value is mapped to a double equal to + (1 if that term appears in that row, 0 otherwise / the number of terms in + that row) * the log of (the number of rows in `x` / (1 + the number of + rows in `x` where the term appears at least once)) + + NOTE: + This is intented to be used with the feature_column 'sum' combiner to arrive + at the true term frequncies. + """ + + def _to_vocab_range(x): + """Enforces that the vocab_ids in x are positive.""" + return tf.SparseTensor( + indices=x.indices, + values=tf.mod(x.values, vocab_size), + dense_shape=x.dense_shape) + + def _to_doc_contains_term(x): + """Creates a SparseTensor with 1s at every doc/term pair index. + + Args: + x : a SparseTensor of int64 representing string indices in vocab. + + Returns: + a SparseTensor with 1s at indices , + for every term/doc pair. + """ + # Construct intermediary sparse tensor with indices + # [, , ] and tf.ones values. + split_indices = tf.to_int64( + tf.split(x.indices, axis=1, num_or_size_splits=2)) + expanded_values = tf.to_int64(tf.expand_dims(x.values, 1)) + next_index = tf.concat( + [split_indices[0], split_indices[1], expanded_values], axis=1) + + next_values = tf.ones_like(x.values) + vocab_size_as_tensor = tf.constant([vocab_size], dtype=tf.int64) + next_shape = tf.concat( + [x.dense_shape, vocab_size_as_tensor], 0) + + next_tensor = tf.SparseTensor( + indices=tf.to_int64(next_index), + values=next_values, + dense_shape=next_shape) + + # Take the intermediar tensor and reduce over the term_index_in_doc + # dimension. This produces a tensor with indices [, ] + # and values [count_of_term_in_doc] and shape batch x vocab_size + term_count_per_doc = tf.sparse_reduce_sum_sparse(next_tensor, 1) + + one_if_doc_contains_term = tf.SparseTensor( + indices=term_count_per_doc.indices, + values=tf.to_double(tf.greater(term_count_per_doc.values, 0)), + dense_shape=term_count_per_doc.dense_shape) + + return one_if_doc_contains_term + + def _to_idf_over_doc_size(x, reduced_term_freq, corpus_size): + """Calculates the inverse document frequency of terms in the corpus. + + Args: + x : a `SparseTensor` of int64 representing string indices in vocab. + reduced_term_freq: A `Tensor` of shape (vocabSize,) that represents the + count of the number of documents with each term. + corpus_size: A scalar count of the number of documents in the corpus + + Returns: + The tf*idf values + """ + # Add one to the reduced term freqnencies to avoid dividing by zero. 
+ idf = tf.log(tf.to_double(corpus_size) / ( + 1.0 + tf.to_double(reduced_term_freq))) + + dense_doc_sizes = tf.to_double(tf.sparse_reduce_sum(tf.SparseTensor( + indices=x.indices, + values=tf.ones_like(x.values), + dense_shape=x.dense_shape), 1)) + + # For every term in x, divide the idf by the doc size. + # The two gathers both result in shape + idf_over_doc_size = (tf.gather(idf, x.values) / + tf.gather(dense_doc_sizes, x.indices[:, 0])) + + return tf.SparseTensor( + indices=x.indices, + values=tf.to_float(idf_over_doc_size), + dense_shape=x.dense_shape) + + cleaned_input = _to_vocab_range(x) + + docs_with_terms = _to_doc_contains_term(cleaned_input) + + def count_docs_with_term(term_frequency): + # Sum w/in batch. + count_of_doc_inter = tf.SparseTensor( + indices=term_frequency.indices, + values=tf.ones_like(term_frequency.values), + dense_shape=term_frequency.dense_shape) + out = tf.sparse_reduce_sum(count_of_doc_inter, axis=0) + return tf.expand_dims(out, 0) + + count_docs_with_term_column = count_docs_with_term(docs_with_terms) + # Expand dims to get around the min_tensor_rank checks + sizes = tf.expand_dims(tf.shape(cleaned_input)[0], 0) + return _to_idf_over_doc_size(cleaned_input, + analyzers.sum(count_docs_with_term_column, + reduce_instance_dims=False), + analyzers.sum(sizes)) diff --git a/tensorflow_transform/mappers_test.py b/tensorflow_transform/mappers_test.py index 6c611a70..38d704aa 100644 --- a/tensorflow_transform/mappers_test.py +++ b/tensorflow_transform/mappers_test.py @@ -27,6 +27,18 @@ class MappersTest(test_util.TensorFlowTestCase): + def assertSparseOutput(self, expected_indices, expected_values, + expected_shape, actual_sparse_tensor, close_values): + with tf.Session() as sess: + sess.run(tf.tables_initializer()) + actual = actual_sparse_tensor.eval() + self.assertAllEqual(expected_indices, actual.indices) + self.assertAllEqual(expected_shape, actual.dense_shape) + if close_values: + self.assertAllClose(expected_values, actual.values) + else: + self.assertAllEqual(expected_values, actual.values) + def testSegmentIndices(self): with tf.Session(): self.assertAllEqual( @@ -37,54 +49,244 @@ def testSegmentIndices(self): mappers.segment_indices(tf.constant([], tf.int64)).eval(), []) - def testNGrams(self): - output_tensor = mappers.ngrams( - tf.constant(['abc', 'def', 'fghijklm', 'z', '']), (1, 5)) + def testSegmentIndicesSkipOne(self): + input_tensor = tf.constant([0, 0, 2, 2]) + with tf.Session(): + self.assertAllEqual([0, 1, 0, 1], + mappers.segment_indices(input_tensor).eval()) + + def testNGramsEmpty(self): + output_tensor = mappers.ngrams(tf.string_split(tf.constant([''])), + (1, 5), '') with tf.Session(): output = output_tensor.eval() - self.assertAllEqual( - output.indices, - [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], - [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], - [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], - [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], [2, 15], - [2, 16], [2, 17], [2, 18], [2, 19], [2, 20], [2, 21], [2, 22], - [2, 23], [2, 24], [2, 25], [2, 26], [2, 27], [2, 28], [2, 29], - [3, 0]]) - self.assertAllEqual(output.values, [ - 'a', 'ab', 'abc', 'b', 'bc', 'c', - 'd', 'de', 'def', 'e', 'ef', 'f', - 'f', 'fg', 'fgh', 'fghi', 'fghij', 'g', 'gh', 'ghi', 'ghij', 'ghijk', - 'h', 'hi', 'hij', 'hijk', 'hijkl', 'i', 'ij', 'ijk', 'ijkl', 'ijklm', - 'j', 'jk', 'jkl', 'jklm', 'k', 'kl', 'klm', 'l', 'lm', 'm', - 'z']) - self.assertAllEqual(output.dense_shape, [5, 30]) + self.assertEqual((0, 2), 
output.indices.shape) + self.assertAllEqual([1, 0], output.dense_shape) + self.assertEqual(0, len(output.values)) + + def testNGrams(self): + string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) + tokenized_tensor = tf.string_split(string_tensor, delimiter='') + output_tensor = mappers.ngrams( + tokens=tokenized_tensor, + ngram_range=(1, 5), + separator='') + self.assertSparseOutput( + expected_indices=[ + [0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], + [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], + [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], + [2, 15], [2, 16], [2, 17], [2, 18], [2, 19], [2, 20], [2, 21], + [2, 22], [2, 23], [2, 24], [2, 25], [2, 26], [2, 27], [2, 28], + [2, 29], [3, 0]], + expected_values=[ + 'a', 'ab', 'abc', 'b', 'bc', 'c', + 'd', 'de', 'def', 'e', 'ef', 'f', + 'f', 'fg', 'fgh', 'fghi', 'fghij', 'g', 'gh', 'ghi', 'ghij', + 'ghijk', 'h', 'hi', 'hij', 'hijk', 'hijkl', 'i', 'ij', 'ijk', + 'ijkl', 'ijklm', 'j', 'jk', 'jkl', 'jklm', 'k', 'kl', 'klm', 'l', + 'lm', 'm', 'z'], + expected_shape=[5, 30], + actual_sparse_tensor=output_tensor, + close_values=False) def testNGramsMinSizeNotOne(self): + string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) + tokenized_tensor = tf.string_split(string_tensor, delimiter='') output_tensor = mappers.ngrams( - tf.constant(['abc', 'def', 'fghijklm', 'z', '']), (2, 5)) + tokens=tokenized_tensor, + ngram_range=(2, 5), + separator='') + self.assertSparseOutput( + expected_indices=[ + [0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], + [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], + [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], + [2, 15], [2, 16], [2, 17], [2, 18], [2, 19], [2, 20], [2, 21]], + expected_values=[ + 'ab', 'abc', 'bc', + 'de', 'def', 'ef', + 'fg', 'fgh', 'fghi', 'fghij', 'gh', 'ghi', 'ghij', 'ghijk', + 'hi', 'hij', 'hijk', 'hijkl', 'ij', 'ijk', 'ijkl', 'ijklm', + 'jk', 'jkl', 'jklm', 'kl', 'klm', 'lm'], + expected_shape=[5, 22], + actual_sparse_tensor=output_tensor, + close_values=False) + + def testNGramsWithSpaceSeparator(self): + string_tensor = tf.constant(['One was Johnny', 'Two was a rat']) + tokenized_tensor = tf.string_split(string_tensor, delimiter=' ') + output_tensor = mappers.ngrams( + tokens=tokenized_tensor, + ngram_range=(1, 2), + separator=' ') with tf.Session(): output = output_tensor.eval() self.assertAllEqual( output.indices, - [[0, 0], [0, 1], [0, 2], - [1, 0], [1, 1], [1, 2], - [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], - [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], [2, 15], - [2, 16], [2, 17], [2, 18], [2, 19], [2, 20], [2, 21]]) + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]) self.assertAllEqual(output.values, [ - 'ab', 'abc', 'bc', - 'de', 'def', 'ef', - 'fg', 'fgh', 'fghi', 'fghij', 'gh', 'ghi', 'ghij', 'ghijk', - 'hi', 'hij', 'hijk', 'hijkl', 'ij', 'ijk', 'ijkl', 'ijklm', - 'jk', 'jkl', 'jklm', 'kl', 'klm', 'lm']) - self.assertAllEqual(output.dense_shape, [5, 22]) + 'One', 'One was', 'was', 'was Johnny', 'Johnny', + 'Two', 'Two was', 'was', 'was a', 'a', 'a rat', 'rat']) + self.assertAllEqual(output.dense_shape, [2, 7]) def testNGramsBadSizes(self): + string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) + tokenized_tensor = tf.string_split(string_tensor, delimiter='') with self.assertRaisesRegexp(ValueError, 'Invalid ngram_range'): - 
mappers.ngrams(tf.constant(['abc', 'def', 'fghijklm', 'z', '']), (0, 5)) + mappers.ngrams(tokenized_tensor, (0, 5), separator='') with self.assertRaisesRegexp(ValueError, 'Invalid ngram_range'): - mappers.ngrams(tf.constant(['abc', 'def', 'fghijklm', 'z', '']), (6, 5)) + mappers.ngrams(tokenized_tensor, (6, 5), separator='') + + def testTermFrequency(self): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]], + [1, 2, 0, 0, 0, 3, 0], + [2, 5]) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + expected_values=[(3/5), (1/5), (1/5), (1/2), (1/2)], + expected_shape=[2, 4], + actual_sparse_tensor=mappers._to_term_frequency(input_tensor, 4), + close_values=True) + + def testTermFrequencyUnusedTerm(self): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]], + [4, 2, 0, 0, 0, 3, 0], + [2, 5]) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 2], [0, 4], [1, 0], [1, 3]], + expected_values=[(3/5), (1/5), (1/5), (1/2), (1/2)], + expected_shape=[2, 5], + actual_sparse_tensor=mappers._to_term_frequency(input_tensor, 5), + close_values=True) + + def testCountDocsWithTerm(self): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + [(3/5), (1/5), (1/5), (1/2), (1/2)], + [2, 4]) + output_tensor = mappers._count_docs_with_term(input_tensor) + with tf.Session(): + output = output_tensor.eval() + self.assertAllEqual([[2, 1, 1, 1]], output) + + def testCountDocsWithTermUnusedTerm(self): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 2], [1, 0], [1, 3]], + [(3/5), (1/5), (1/2), (1/2)], + [2, 4]) + output_tensor = mappers._count_docs_with_term(input_tensor) + with tf.Session(): + output = output_tensor.eval() + self.assertAllEqual([[2, 0, 1, 1]], output) + + def testToTFIDF(self): + term_freq = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + [(3/5), (1/5), (1/5), (1/2), (1/2)], + [2, 4]) + reduced_term_freq = tf.constant([[2, 1, 1, 1]]) + output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq, 2, True) + log_3_over_2 = 0.4054651 + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + expected_values=[0, (1/5)*log_3_over_2, (1/5)*log_3_over_2, + 0, (1/2)*log_3_over_2], + expected_shape=[2, 4], + actual_sparse_tensor=output_tensor, + close_values=True) + + def testToTFIDFNotSmooth(self): + term_freq = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + [(3/5), (1/5), (1/5), (1/2), (1/2)], + [2, 4]) + reduced_term_freq = tf.constant([[2, 1, 1, 1]]) + output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq, 2, False) + log_2_over_1 = 0.6931471 + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + expected_values=[0, (1/5)*log_2_over_1, (1/5)*log_2_over_1, + 0, (1/2)*log_2_over_1], + expected_shape=[2, 4], + actual_sparse_tensor=output_tensor, + close_values=True) + + def testSplitTFIDF(self): + tfidfs = tf.SparseTensor( + [[0, 0], [0, 1], [2, 1], [2, 2]], + [0.23104906, 0.19178806, 0.14384104, 0.34657359], + [3, 4]) + + out_index, out_weight = mappers._split_tfidfs_to_outputs(tfidfs) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [2, 0], [2, 1]], + expected_values=[0, 1, 1, 2], + expected_shape=[3, 2], + actual_sparse_tensor=out_index, + close_values=False) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [2, 0], [2, 1]], + expected_values=[0.23104906, 0.19178806, 0.14384104, 0.34657359], + expected_shape=[3, 2], + 
actual_sparse_tensor=out_weight, + close_values=True) + + def testHashStringsNoKeyDenseInput(self): + strings = tf.constant(['Car', 'Bus', 'Tree']) + expected_output = [8, 4, 5] + + hash_buckets = 11 + hashed_strings = mappers.hash_strings(strings, hash_buckets) + with self.test_session() as sess: + output = sess.run(hashed_strings) + self.assertAllEqual(expected_output, output) + + def testHashStringsNoKeySparseInput(self): + strings = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]], + values=['Dog', 'Cat', ''], + dense_shape=[2, 2]) + hash_buckets = 17 + expected_indices = [[0, 0], [0, 1], [1, 0]] + expected_values = [12, 4, 11] + expected_shape = [2, 2] + hashed_strings = mappers.hash_strings(strings, hash_buckets) + self.assertSparseOutput( + expected_indices=expected_indices, + expected_values=expected_values, + expected_shape=expected_shape, + actual_sparse_tensor=hashed_strings, + close_values=False) + + def testHashStringsWithKeyDenseInput(self): + strings = tf.constant(['Cake', 'Pie', 'Sundae']) + expected_output = [6, 5, 6] + hash_buckets = 11 + hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[123, 456]) + with self.test_session() as sess: + output = sess.run(hashed_strings) + self.assertAllEqual(expected_output, output) + + def testHashStringsWithKeySparseInput(self): + strings = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [2, 0]], + values=['$$$', '%^#', '&$!#@', '$$$'], + dense_shape=[3, 2]) + hash_buckets = 173 + expected_indices = [[0, 0], [0, 1], [1, 0], [2, 0]] + expected_values = [16, 156, 9, 16] + expected_shape = [3, 2] + hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[321, 555]) + self.assertSparseOutput( + expected_indices=expected_indices, + expected_values=expected_values, + expected_shape=expected_shape, + actual_sparse_tensor=hashed_strings, + close_values=False) if __name__ == '__main__': diff --git a/tensorflow_transform/saved/input_fn_maker.py b/tensorflow_transform/saved/input_fn_maker.py index 38f0cf51..abdfb28d 100644 --- a/tensorflow_transform/saved/input_fn_maker.py +++ b/tensorflow_transform/saved/input_fn_maker.py @@ -106,7 +106,8 @@ def default_transforming_serving_input_fn(): record_defaults = [] for k in raw_keys: - if column_schemas[k].representation.default_value: + if column_schemas[k].representation.default_value is not None: + # Note that 0 and '' are valid defaults. value = tf.constant([column_schemas[k].representation.default_value], dtype=column_schemas[k].domain.dtype) else: @@ -133,6 +134,67 @@ def default_transforming_serving_input_fn(): return default_transforming_serving_input_fn +def build_json_example_transforming_serving_input_fn( + raw_metadata, + transform_savedmodel_dir, + raw_label_keys, + raw_feature_keys=None, + convert_scalars_to_vectors=True): + """Creates input_fn that applies transforms to raw data formatted in json. + + The json is formatted as tf.examples. For example, one input row could contain + the string for + + {"features": {"feature": {"name": {"int64List": {"value": [42]}}}}} + + which encodes an example containing only feature column 'name' with value 42. + + Args: + raw_metadata: a `DatasetMetadata` object describing the raw data. + transform_savedmodel_dir: a SavedModel directory produced by tf.Transform + embodying a transformation function to be applied to incoming raw data. + raw_label_keys: A list of string keys of the raw labels to be used. These + labels are removed from the serving graph. 
To build a serving function + that expects labels in the input at serving time, pass raw_labels_keys=[]. + raw_feature_keys: A list of string keys of the raw features to be used. + If None or empty, defaults to all features except labels. + convert_scalars_to_vectors: Boolean specifying whether this input_fn should + convert scalars into 1-d vectors. This is necessary if the inputs will be + used with `FeatureColumn`s as `FeatureColumn`s cannot accept scalar + inputs. Default: True. + + Returns: + An input_fn suitable for serving that applies transforms to raw data in + tf.Examples. + """ + + raw_feature_spec = raw_metadata.schema.as_feature_spec() + raw_feature_keys = _prepare_feature_keys(raw_metadata, + raw_label_keys, + raw_feature_keys) + raw_serving_feature_spec = {key: raw_feature_spec[key] + for key in raw_feature_keys} + + def _serving_input_fn(): + """Applies transforms to raw data in json-example strings.""" + + json_example_placeholder = tf.placeholder(tf.string, shape=[None]) + example_strings = tf.decode_json_example(json_example_placeholder) + raw_features = tf.parse_example(example_strings, raw_serving_feature_spec) + inputs = {"json_example": json_example_placeholder} + + _, transformed_features = ( + saved_transform_io.partially_apply_saved_transform( + transform_savedmodel_dir, raw_features)) + + if convert_scalars_to_vectors: + transformed_features = _convert_scalars_to_vectors(transformed_features) + + return input_fn_utils.InputFnOps(transformed_features, None, inputs) + + return _serving_input_fn + + def build_parsing_transforming_serving_input_fn( raw_metadata, transform_savedmodel_dir, diff --git a/tensorflow_transform/saved/input_fn_maker_test.py b/tensorflow_transform/saved/input_fn_maker_test.py index d3173499..b64a6212 100644 --- a/tensorflow_transform/saved/input_fn_maker_test.py +++ b/tensorflow_transform/saved/input_fn_maker_test.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function +import json import os import tempfile @@ -63,6 +64,43 @@ def _make_transformed_schema(shape): class InputFnMakerTest(unittest.TestCase): + def test_build_csv_transforming_serving_input_fn_with_defaults(self): + feed_dict = [',,'] + + basedir = tempfile.mkdtemp() + + raw_metadata = dataset_metadata.DatasetMetadata( + schema=_make_raw_schema([])) + + transform_savedmodel_dir = os.path.join(basedir, 'transform-savedmodel') + _write_transform_savedmodel(transform_savedmodel_dir) + + serving_input_fn = ( + input_fn_maker.build_csv_transforming_serving_input_fn( + raw_metadata=raw_metadata, + raw_keys=['raw_a', 'raw_b', 'raw_label'], + transform_savedmodel_dir=transform_savedmodel_dir)) + + with tf.Graph().as_default(): + with tf.Session().as_default() as session: + outputs, labels, inputs = serving_input_fn() + feed_inputs = {inputs['csv_example']: feed_dict} + transformed_a, transformed_b, transformed_label = session.run( + [outputs['transformed_a'], outputs['transformed_b'], + outputs['transformed_label']], + feed_dict=feed_inputs) + + # Note the feed dict is empy. 
So these values come from the defaults + # in _make_raw_schema() + self.assertEqual(1, transformed_a[0][0]) + self.assertEqual(-1, transformed_b[0][0]) + self.assertEqual(-1000, transformed_label[0][0]) + self.assertItemsEqual( + outputs, + {'transformed_a', 'transformed_b', 'transformed_label'}) + self.assertIsNone(labels) + self.assertEqual(set(inputs.keys()), {'csv_example'}) + def test_build_csv_transforming_serving_input_fn_with_label(self): feed_dict = ['15,6,1', '12,17,2'] @@ -102,6 +140,79 @@ def test_build_csv_transforming_serving_input_fn_with_label(self): self.assertIsNone(labels) self.assertEqual(set(inputs.keys()), {'csv_example'}) + def test_build_json_example_transforming_serving_input_fn(self): + example_all = { + 'features': { + 'feature': { + 'raw_a': { + 'int64List': { + 'value': [42] + } + }, + 'raw_b': { + 'int64List': { + 'value': [43] + } + }, + 'raw_label': { + 'int64List': { + 'value': [44] + } + } + } + } + } + # Default values for raw_a and raw_b come from _make_raw_schema() + example_missing = { + 'features': { + 'feature': { + 'raw_label': { + 'int64List': { + 'value': [3] + } + } + } + } + } + feed_dict = [json.dumps(example_all), json.dumps(example_missing)] + + basedir = tempfile.mkdtemp() + + raw_metadata = dataset_metadata.DatasetMetadata( + schema=_make_raw_schema([])) + + transform_savedmodel_dir = os.path.join(basedir, 'transform-savedmodel') + _write_transform_savedmodel(transform_savedmodel_dir) + + serving_input_fn = ( + input_fn_maker.build_json_example_transforming_serving_input_fn( + raw_metadata=raw_metadata, + raw_label_keys=[], + raw_feature_keys=['raw_a', 'raw_b', 'raw_label'], + transform_savedmodel_dir=transform_savedmodel_dir)) + + with tf.Graph().as_default(): + with tf.Session().as_default() as session: + outputs, labels, inputs = serving_input_fn() + feed_inputs = {inputs['json_example']: feed_dict + } + transformed_a, transformed_b, transformed_label = session.run( + [outputs['transformed_a'], outputs['transformed_b'], + outputs['transformed_label']], + feed_dict=feed_inputs) + + self.assertEqual(85, transformed_a[0][0]) + self.assertEqual(-1, transformed_b[0][0]) + self.assertEqual(44000, transformed_label[0][0]) + self.assertEqual(1, transformed_a[1][0]) + self.assertEqual(-1, transformed_b[1][0]) + self.assertEqual(3000, transformed_label[1][0]) + self.assertItemsEqual( + outputs, + {'transformed_a', 'transformed_b', 'transformed_label'}) + self.assertIsNone(labels) + self.assertEqual(set(inputs.keys()), {'json_example'}) + def test_build_parsing_transforming_serving_input_fn_scalars(self): self._test_build_parsing_transforming_serving_input_fn( _make_raw_schema([]))
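For context on how the new json-example serving input_fn is driven: each element fed to the `json_example` placeholder is one JSON-encoded `tf.Example`, the encoding that `tf.decode_json_example` expects. A sketch of building such a request on the client side (feature names borrowed from the test schema above; everything else is an assumption for illustration):

```python
# Hypothetical client-side construction of a request batch for the serving
# graph built by build_json_example_transforming_serving_input_fn. Each row is
# a JSON-encoded tf.Example; features omitted here fall back to the defaults
# declared in the raw schema.
import json

example = {
    'features': {
        'feature': {
            'raw_a': {'int64List': {'value': [42]}},
            'raw_b': {'int64List': {'value': [43]}},
        }
    }
}
request_batch = [json.dumps(example)]  # fed to the [None]-shaped string placeholder
```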