KeyValueTensorInitializer based on tft.count_per_key output raises AssertionError: Tried to export a function which references untracked resource Tensor #240

Open
@jccarles

Hello TFT team,

I am migrating to tensorflow-transform 1.0 and am running into a tracing issue with TF2 behavior enabled. I want to instantiate a lookup table from the result of a tensorflow-transform analyzer, but when tft_beam.AnalyzeDataset tries to save the transformation graph, it fails with an error referring to an "untracked resource Tensor".

Versions

  • tensorflow==2.5.0
  • tensorflow-transform==1.0.0
  • apache-beam==2.29.0

Steps to reproduce

I created a small snippet that reproduces the error. It creates a few data examples and a basic preprocessing_fn that instantiates a KeyValueTensorInitializer from the result of the analyzer tft.count_per_key.

import tempfile
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

_TEMP_DIR = tempfile.mkdtemp()
_LABEL_KEY = "label"
_INT_DTYPE = tf.int64
_FLOAT_DTYPE = tf.float32


def _generate_data_vector():
    train_raw_data = [
        {_LABEL_KEY: [1]},
        {_LABEL_KEY: [0]},
        {_LABEL_KEY: [0]},
        {_LABEL_KEY: [0]},
    ]
    raw_data_metadata = dataset_metadata.DatasetMetadata(
        schema_utils.schema_from_feature_spec(
            {_LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=_INT_DTYPE),}
        )
    )
    
    return train_raw_data, raw_data_metadata


def example_label_preprocessing(tensor: tf.Tensor) -> tf.Tensor:
    tensor = tf.squeeze(tensor, axis=-1)
    keys, values = tft.count_per_key(tensor)
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=keys,
        values=tf.cast(values, tf.float32),
    )
    
    table = tf.lookup.StaticHashTable(
        initializer,
        default_value=0.0,
    )
    return table.lookup(tensor)


def _train_and_retrieve_trained_data(train_data, feature_name):
    def preprocessing_fn(inputs):
        label = inputs[_LABEL_KEY]
        return {feature_name: example_label_preprocessing(tensor=label)}
    
    # apply preprocessing function to retrieve transform ops
    train_transformed = train_data | tft_beam.AnalyzeAndTransformDataset(
        preprocessing_fn
    )
    return train_transformed


if __name__ == "__main__":
    with tft_beam.Context(
            temp_dir=_TEMP_DIR, force_tf_compat_v1=False
    ):
        _data = _generate_data_vector()
        _FEATURE_NAME = "feature"
        trained_data = _train_and_retrieve_trained_data(_data, _FEATURE_NAME)
        print(trained_data)
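
To make the intended behavior concrete, here is a plain-Python sketch (no TensorFlow involved) of what the preprocessing above is meant to compute on the four examples: each label is replaced by the number of times it occurs in the dataset. The helper name `count_lookup` is hypothetical and only illustrates the expected values; it is not part of the failing script and does not reproduce the error.

```python
from collections import Counter
from typing import List


def count_lookup(labels: List[int]) -> List[float]:
    """Replace each label with its frequency in `labels`."""
    # Counter stands in for tft.count_per_key (assumption: same per-key counts).
    counts = Counter(labels)
    # dict.get with a 0.0 default mirrors StaticHashTable(default_value=0.0).
    return [float(counts.get(label, 0.0)) for label in labels]


print(count_lookup([1, 0, 0, 0]))  # [1.0, 3.0, 3.0, 3.0]
```

With the sample data, label 1 occurs once and label 0 occurs three times, so a successful run of the pipeline would be expected to map the labels to 1.0, 3.0, 3.0, 3.0.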

Stack trace

When the above script is run I get the following error:

WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).
WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).
2021-06-18 11:01:48.849130: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING:tensorflow:From /Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/tf_utils.py:266: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.
WARNING:tensorflow:From /Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/tf_utils.py:266: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use ref() instead.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
WARNING:tensorflow:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.
2021-06-18 11:01:49.914210: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).
WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).
WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/Users/jccarles/PycharmProjects/machine-learning-pythonNew/sandbox/jc/reproduce_bug_tft.py']
2021-06-18 11:01:52.406934: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
Traceback (most recent call last):
  File "apache_beam/runners/common.py", line 1233, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 762, in apache_beam.runners.common.PerWindowInvoker.invoke_process
  File "apache_beam/runners/common.py", line 887, in apache_beam.runners.common.PerWindowInvoker._invoke_process_per_window
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/transforms/core.py", line 1586, in <lambda>
    wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)]
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/beam/impl.py", line 684, in _create_v2_saved_model
    output_keys_to_name_map)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py", line 721, in trace_and_write_v2_saved_model
    tensor_replacement_map, output_keys_to_name_map)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py", line 663, in _trace_and_write_transform_fn
    saved_model_dir)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/saved/saved_transform_io_v2.py", line 489, in write_v2_saved_model
    tf.saved_model.save(module, saved_model_dir)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1066, in save
    raise_metadata_warning=True)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1104, in save_and_return_nodes
    raise_metadata_warning))
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1291, in _build_meta_graph
    raise_metadata_warning)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1225, in _build_meta_graph_impl
    options.namespace_whitelist)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 691, in _fill_meta_graph_def
    [], resource_map))
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 534, in _call_function_with_mapped_captures
    resource_map)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 457, in _map_captures_to_created_tensors
    "\n".join([repr(obj) for obj in trackable_referrers])))
AssertionError: Tried to export a function which references untracked resource Tensor("key_value_init/LookupTableImportV2/count_per_key/StringToNumber:0", shape=(None,), dtype=int64). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

Trackable Python objects referring to this tensor (from gc.get_referrers, limited to two hops):
<tensorflow.python.ops.lookup_ops.KeyValueTensorInitializer object at 0x7f885bef5f98>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/sandbox/jc/reproduce_bug_tft.py", line 62, in <module>
    trained_data = _train_and_retrieve_trained_data(_data, _FEATURE_NAME)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/sandbox/jc/reproduce_bug_tft.py", line 51, in _train_and_retrieve_trained_data
    preprocessing_fn
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/transforms/ptransform.py", line 581, in __ror__
    p.run().wait_until_finish()
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/pipeline.py", line 561, in run
    return self.runner.run_pipeline(self, self._options)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/direct/direct_runner.py", line 133, in run_pipeline
    return runner.run_pipeline(pipeline, options)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 194, in run_pipeline
    pipeline.to_runner_api(default_environment=self._default_environment))
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 204, in run_via_runner_api
    return self.run_stages(stage_context, stages)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 370, in run_stages
    bundle_context_manager,
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 566, in _run_stage
    bundle_manager)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 606, in _run_bundle
    data_input, data_output, input_timers, expected_timer_output)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/fn_runner.py", line 907, in process_bundle
    result_future = self._worker_handler.control_conn.push(process_bundle_req)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/portability/fn_api_runner/worker_handlers.py", line 381, in push
    response = self.worker.do_instruction(request)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 607, in do_instruction
    getattr(request, request_type), request.instruction_id)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/worker/sdk_worker.py", line 644, in process_bundle
    bundle_processor.process_bundle(instruction_id))
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 1001, in process_bundle
    element.data)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/runners/worker/bundle_processor.py", line 229, in process_encoded
    self.output(decoded_value)
  File "apache_beam/runners/worker/operations.py", line 356, in apache_beam.runners.worker.operations.Operation.output
  File "apache_beam/runners/worker/operations.py", line 358, in apache_beam.runners.worker.operations.Operation.output
  File "apache_beam/runners/worker/operations.py", line 220, in apache_beam.runners.worker.operations.SingletonConsumerSet.receive
  File "apache_beam/runners/worker/operations.py", line 717, in apache_beam.runners.worker.operations.DoOperation.process
  File "apache_beam/runners/worker/operations.py", line 718, in apache_beam.runners.worker.operations.DoOperation.process
  File "apache_beam/runners/common.py", line 1235, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 1300, in apache_beam.runners.common.DoFnRunner._reraise_augmented
  File "apache_beam/runners/common.py", line 1233, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 581, in apache_beam.runners.common.SimpleInvoker.invoke_process
  File "apache_beam/runners/common.py", line 1395, in apache_beam.runners.common._OutputProcessor.process_outputs
  File "apache_beam/runners/worker/operations.py", line 220, in apache_beam.runners.worker.operations.SingletonConsumerSet.receive
  File "apache_beam/runners/worker/operations.py", line 717, in apache_beam.runners.worker.operations.DoOperation.process
  File "apache_beam/runners/worker/operations.py", line 718, in apache_beam.runners.worker.operations.DoOperation.process
  File "apache_beam/runners/common.py", line 1235, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 1300, in apache_beam.runners.common.DoFnRunner._reraise_augmented
  File "apache_beam/runners/common.py", line 1233, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 581, in apache_beam.runners.common.SimpleInvoker.invoke_process
  File "apache_beam/runners/common.py", line 1395, in apache_beam.runners.common._OutputProcessor.process_outputs
  File "apache_beam/runners/worker/operations.py", line 220, in apache_beam.runners.worker.operations.SingletonConsumerSet.receive
  File "apache_beam/runners/worker/operations.py", line 717, in apache_beam.runners.worker.operations.DoOperation.process
  File "apache_beam/runners/worker/operations.py", line 718, in apache_beam.runners.worker.operations.DoOperation.process
  File "apache_beam/runners/common.py", line 1235, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 1315, in apache_beam.runners.common.DoFnRunner._reraise_augmented
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/future/utils/__init__.py", line 446, in raise_with_traceback
    raise exc.with_traceback(traceback)
  File "apache_beam/runners/common.py", line 1233, in apache_beam.runners.common.DoFnRunner.process
  File "apache_beam/runners/common.py", line 762, in apache_beam.runners.common.PerWindowInvoker.invoke_process
  File "apache_beam/runners/common.py", line 887, in apache_beam.runners.common.PerWindowInvoker._invoke_process_per_window
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/apache_beam/transforms/core.py", line 1586, in <lambda>
    wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)]
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/beam/impl.py", line 684, in _create_v2_saved_model
    output_keys_to_name_map)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py", line 721, in trace_and_write_v2_saved_model
    tensor_replacement_map, output_keys_to_name_map)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/impl_helper.py", line 663, in _trace_and_write_transform_fn
    saved_model_dir)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow_transform/saved/saved_transform_io_v2.py", line 489, in write_v2_saved_model
    tf.saved_model.save(module, saved_model_dir)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1066, in save
    raise_metadata_warning=True)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1104, in save_and_return_nodes
    raise_metadata_warning))
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1291, in _build_meta_graph
    raise_metadata_warning)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1225, in _build_meta_graph_impl
    options.namespace_whitelist)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 691, in _fill_meta_graph_def
    [], resource_map))
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 534, in _call_function_with_mapped_captures
    resource_map)
  File "/Users/jccarles/PycharmProjects/machine-learning-pythonNew/tfx_1/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 457, in _map_captures_to_created_tensors
    "\n".join([repr(obj) for obj in trackable_referrers])))
AssertionError: Tried to export a function which references untracked resource Tensor("key_value_init/LookupTableImportV2/count_per_key/StringToNumber:0", shape=(None,), dtype=int64). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

Trackable Python objects referring to this tensor (from gc.get_referrers, limited to two hops):
<tensorflow.python.ops.lookup_ops.KeyValueTensorInitializer object at 0x7f885bef5f98> [while running 'AnalyzeAndTransformDataset/AnalyzeDataset/CreateSavedModelForAnalyzerInputs[Phase0][tf_v2_only]/CreateSavedModel']

I am unsure what is causing the issue, although I do not encounter it when force_tf_compat_v1=True. Any help would be appreciated.
