diff --git a/symbolic_pymc/meta.py b/symbolic_pymc/meta.py index 64476ca..9425c6a 100644 --- a/symbolic_pymc/meta.py +++ b/symbolic_pymc/meta.py @@ -139,6 +139,12 @@ def meta_reify_iter(rands): return type(rands)(reified_rands), any_unreified +class MetaReificationError(Exception): + """An exception type for errors encountered during the creation of base objects from meta objects.""" + + pass + + class MetaSymbolType(abc.ABCMeta): def __new__(cls, name, bases, clsdict): diff --git a/symbolic_pymc/tensorflow/meta.py b/symbolic_pymc/tensorflow/meta.py index 67b0a12..68fad81 100644 --- a/symbolic_pymc/tensorflow/meta.py +++ b/symbolic_pymc/tensorflow/meta.py @@ -16,12 +16,16 @@ from google.protobuf.message import Message -from tensorflow.python.framework import tensor_util, op_def_registry, op_def_library, tensor_shape +from tensorflow.python.framework import ( + tensor_util, + op_def_registry, + op_def_library, + tensor_shape, + ops, +) from tensorflow.core.framework.op_def_pb2 import OpDef from tensorflow.core.framework.node_def_pb2 import NodeDef -# from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto - from tensorflow_probability import distributions as tfd @@ -30,6 +34,7 @@ MetaSymbolType, MetaOp, MetaVariable, + MetaReificationError, meta_reify_iter, _metatize, metatize, @@ -60,9 +65,42 @@ class MetaOpDefLibrary(object): } opdef_signatures = {} + def __init__(self): + # + # We need this in order to construct "Const" tensors directly, since + # the "value" attr in a meta `NodeDef` is just a NumPy array and not + # the `TensorProto` expected by `raw_ops.Const`. + # + def mt_const(value, dtype, name=None): + return tf.raw_ops.Const( + value=tensor_util.make_tensor_proto(value), dtype=dtype, name=name + ) + + opdef = op_def_registry.get("Const") + self.opdef_signatures[opdef.name] = self.make_opdef_sig(opdef, mt_const) + @classmethod - def apply_op(cls, *args, **kwargs): - return op_def_library.apply_op(*args, **kwargs) + def get_op_info(cls, opdef): + """Return the TF Python API function signature for a given `OpDef`. + + Parameter + --------- + opdef: str or `OpDef` object (meta or base) + """ + if isinstance(opdef, str): + opdef_name = opdef + opdef = op_def_registry.get(opdef_name) + else: + opdef_name = opdef.name + + opdef_sig = cls.opdef_signatures.get(opdef_name, None) + + if opdef_sig is None and opdef is not None: + opdef_func = getattr(tf.raw_ops, opdef.name, None) + opdef_sig = cls.make_opdef_sig(opdef, opdef_func) + cls.opdef_signatures[opdef.name] = opdef_sig + + return opdef_sig @classmethod def make_opdef_sig(cls, opdef, opdef_py_func=None): @@ -70,42 +108,25 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None): Annotations are include so that one can partially verify arguments. """ - input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg]) - attrs = OrderedDict([(a.name, a) for a in opdef.attr]) - - params = OrderedDict() if opdef_py_func: + # # We assume we're dealing with a function from `tf.raw_ops`. - # Those functions have only the necessary `input_arg`s and - # `attr` inputs as arguments. + # Those functions have only the necessary `input_arg`s and `attr` + # inputs as arguments. + # opdef_func_sig = Signature.from_callable(opdef_py_func) params = opdef_func_sig.parameters - # for name, param in opdef_func_sig.parameters.items(): - # # We make positional parameters permissible (since the - # # functions in `tf.raw_ops` are keyword-only), and we use the - # # `tf.raw_ops` arguments to determine the *actual* required - # # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't - # # exactly clear about that). - # if name in input_args: - # new_default = Parameter.empty - # new_annotation = input_args[name] - # else: - # new_default = None - # new_annotation = attrs.get(name, None) - # if new_annotation is not None: - # new_annotation = new_annotation.type + else: # - # new_param = param.replace( - # kind=Parameter.POSITIONAL_OR_KEYWORD, - # default=new_default, - # annotation=new_annotation, - # ) - # params[name] = new_param + # We're crafting an `Operation` at a low-level via `apply_op` + # (like the functions in `tf.raw_ops` do) + # + input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg]) + attrs = OrderedDict([(a.name, a) for a in opdef.attr]) + params = OrderedDict() - else: - # We're crafting the Operation at a low-level via `apply_op`. - opdef_py_func = partial(op_def_lib.apply_op, opdef.name) + opdef_py_func = partial(op_def_library.apply_op, opdef.name) for i_name, i_type in input_args.items(): p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type) @@ -144,29 +165,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None): ) return opdef_sig, opdef_py_func - @classmethod - def get_op_info(cls, opdef): - """Return the TF Python API function signature for a given `OpDef`. - - Parameter - --------- - opdef: str or `OpDef` object (meta or base) - """ - if isinstance(opdef, str): - opdef_name = opdef - opdef = op_def_registry.get(opdef_name) - else: - opdef_name = opdef.name - - opdef_sig = cls.opdef_signatures.get(opdef_name, None) - - if opdef_sig is None and opdef is not None: - opdef_func = getattr(tf.raw_ops, opdef.name, None) - opdef_sig = cls.make_opdef_sig(opdef, opdef_func) - cls.opdef_signatures[opdef.name] = cls.make_opdef_sig(opdef, opdef_func) - - return opdef_sig - op_def_lib = MetaOpDefLibrary() @@ -183,7 +181,6 @@ def _metatize_tf_object(obj): def load_dispatcher(): """Set/override dispatcher to default to TF objects.""" - from tensorflow.python.framework.ops import EagerTensor from tensorflow.python.ops.gen_linalg_ops import _SvdOutput def _metatize_tf_svd(obj): @@ -200,7 +197,7 @@ def _metatize_tf_eager(obj): " (e.g. within `tensorflow.python.eager.context.graph_mode`)" ) - meta._metatize.add((EagerTensor,), _metatize_tf_eager) + meta._metatize.add((ops.EagerTensor,), _metatize_tf_eager) meta._metatize.add((object,), _metatize_tf_object) meta._metatize.add((HashableNDArray,), _metatize_tf_object) @@ -599,12 +596,30 @@ def reify(self): ) if not (op_inputs_unreified or op_attrs_unreified or isvar(self.name)): - - apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs) - tf_out = operator._apply_func(**apply_arguments) - op_tf = tf_out.op - - # TODO: Update NodeDef attrs? + # + # An operation with this name might already exist in the graph + # + try: + existing_op = ops.get_default_graph().get_operation_by_name(self.name) + except KeyError: + # + # There is no such `Operation`, so we attempt to create it + # + apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs) + tf_out = operator._apply_func(**apply_arguments) + op_tf = tf_out.op + else: + # + # An `Operation` with this name exists, let's make sure it's + # equivalent to this meta `Operation` + # + if self != mt(existing_op): + raise MetaReificationError( + f"An Operation with the name {self.name}" + " already exists in the graph and is not" + " equal to this meta object." + ) + op_tf = existing_op assert op_tf is not None self._obj = op_tf @@ -1149,4 +1164,5 @@ def __getattr__(self, obj): mt = TFlowMetaAccessor() + load_dispatcher() diff --git a/tests/tensorflow/test_meta.py b/tests/tensorflow/test_meta.py index a2a8450..bb4bd25 100644 --- a/tests/tensorflow/test_meta.py +++ b/tests/tensorflow/test_meta.py @@ -15,7 +15,8 @@ from unification import var, isvar from symbolic_pymc.utils import HashableNDArray -from symbolic_pymc.meta import MetaSymbol, disable_auto_reification, enable_lvar_defaults +from symbolic_pymc.meta import (MetaSymbol, disable_auto_reification, + enable_lvar_defaults) from symbolic_pymc.tensorflow.meta import (TFlowMetaTensor, TFlowMetaTensorShape, TFlowMetaOp, @@ -23,6 +24,7 @@ TFlowMetaNodeDef, TFlowMetaOperator, MetaOpDefLibrary, + MetaReificationError, mt) from tests.tensorflow import run_in_graph_mode @@ -212,7 +214,7 @@ def test_meta_basic(): @run_in_graph_mode -def test_meta_Op(): +def test_meta_operation(): t1_tf = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]]) t2_tf = tf.convert_to_tensor([[7, 8, 9], [10, 11, 12]]) @@ -673,3 +675,53 @@ def test_global_options(): with tf.Graph().as_default(), enable_lvar_defaults('names'): a_mt = mt(1.0) assert isvar(a_mt.name) + + +@run_in_graph_mode +def test_meta_const(): + """Make sure we can create a Const tensor by hand.""" + + with tf.Graph().as_default(): + one_mt = mt.const(1, 'int32', 'Const') + + with tf.Graph().as_default(): + another_one_mt = mt(1) + + assert one_mt == another_one_mt + assert isinstance(one_mt.reify(), tf.Tensor) + assert one_mt.reify().op.type == 'Const' + + +@run_in_graph_mode +def test_meta_existing_names(): + + with tf.Graph().as_default(): + one_mt = mt(1) + assert one_mt.op.name == 'Const' + + # Clear-out the associated base variable + orig_one_tf = one_mt._obj + one_mt.reset() + one_mt.op.reset() + assert one_mt.obj is None + assert one_mt.op.obj is None + + # Attempt to reify to a base variable + one_tf = one_mt.reify() + assert one_tf.op.name == 'Const' + # Make sure it's the first base variable we created + assert orig_one_tf is one_tf + + two_mt = mt(2) + two_mt.op.node_def.name = 'Const' + + # TODO FIXME: We shouldn't have to do this manually after changing a + # dependency. + two_mt.reset() + two_mt.op.reset() + assert two_mt.obj is None + assert two_mt.op.obj is None + assert two_mt.op.name == 'Const' + + with pytest.raises(MetaReificationError): + two_mt.reify()