Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions symbolic_pymc/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def meta_reify_iter(rands):
return type(rands)(reified_rands), any_unreified


class MetaReificationError(Exception):
"""An exception type for errors encountered during the creation of base objects from meta objects."""

pass


class MetaSymbolType(abc.ABCMeta):
def __new__(cls, name, bases, clsdict):

Expand Down
146 changes: 81 additions & 65 deletions symbolic_pymc/tensorflow/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

from google.protobuf.message import Message

from tensorflow.python.framework import tensor_util, op_def_registry, op_def_library, tensor_shape
from tensorflow.python.framework import (
tensor_util,
op_def_registry,
op_def_library,
tensor_shape,
ops,
)
from tensorflow.core.framework.op_def_pb2 import OpDef
from tensorflow.core.framework.node_def_pb2 import NodeDef

# from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto

from tensorflow_probability import distributions as tfd


Expand All @@ -30,6 +34,7 @@
MetaSymbolType,
MetaOp,
MetaVariable,
MetaReificationError,
meta_reify_iter,
_metatize,
metatize,
Expand Down Expand Up @@ -60,52 +65,68 @@ class MetaOpDefLibrary(object):
}
opdef_signatures = {}

def __init__(self):
#
# We need this in order to construct "Const" tensors directly, since
# the "value" attr in a meta `NodeDef` is just a NumPy array and not
# the `TensorProto` expected by `raw_ops.Const`.
#
def mt_const(value, dtype, name=None):
return tf.raw_ops.Const(
value=tensor_util.make_tensor_proto(value), dtype=dtype, name=name
)

opdef = op_def_registry.get("Const")
self.opdef_signatures[opdef.name] = self.make_opdef_sig(opdef, mt_const)

@classmethod
def apply_op(cls, *args, **kwargs):
return op_def_library.apply_op(*args, **kwargs)
def get_op_info(cls, opdef):
"""Return the TF Python API function signature for a given `OpDef`.

Parameter
---------
opdef: str or `OpDef` object (meta or base)
"""
if isinstance(opdef, str):
opdef_name = opdef
opdef = op_def_registry.get(opdef_name)
else:
opdef_name = opdef.name

opdef_sig = cls.opdef_signatures.get(opdef_name, None)

if opdef_sig is None and opdef is not None:
opdef_func = getattr(tf.raw_ops, opdef.name, None)
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
cls.opdef_signatures[opdef.name] = opdef_sig

return opdef_sig

@classmethod
def make_opdef_sig(cls, opdef, opdef_py_func=None):
"""Create a `Signature` object for an `OpDef`.

Annotations are include so that one can partially verify arguments.
"""
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
attrs = OrderedDict([(a.name, a) for a in opdef.attr])

params = OrderedDict()
if opdef_py_func:
#
# We assume we're dealing with a function from `tf.raw_ops`.
# Those functions have only the necessary `input_arg`s and
# `attr` inputs as arguments.
# Those functions have only the necessary `input_arg`s and `attr`
# inputs as arguments.
#
opdef_func_sig = Signature.from_callable(opdef_py_func)
params = opdef_func_sig.parameters

# for name, param in opdef_func_sig.parameters.items():
# # We make positional parameters permissible (since the
# # functions in `tf.raw_ops` are keyword-only), and we use the
# # `tf.raw_ops` arguments to determine the *actual* required
# # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
# # exactly clear about that).
# if name in input_args:
# new_default = Parameter.empty
# new_annotation = input_args[name]
# else:
# new_default = None
# new_annotation = attrs.get(name, None)
# if new_annotation is not None:
# new_annotation = new_annotation.type
else:
#
# new_param = param.replace(
# kind=Parameter.POSITIONAL_OR_KEYWORD,
# default=new_default,
# annotation=new_annotation,
# )
# params[name] = new_param
# We're crafting an `Operation` at a low-level via `apply_op`
# (like the functions in `tf.raw_ops` do)
#
input_args = OrderedDict([(a.name, a.type or a.type_attr) for a in opdef.input_arg])
attrs = OrderedDict([(a.name, a) for a in opdef.attr])
params = OrderedDict()

else:
# We're crafting the Operation at a low-level via `apply_op`.
opdef_py_func = partial(op_def_lib.apply_op, opdef.name)
opdef_py_func = partial(op_def_library.apply_op, opdef.name)

for i_name, i_type in input_args.items():
p = Parameter(i_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=i_type)
Expand Down Expand Up @@ -144,29 +165,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
)
return opdef_sig, opdef_py_func

@classmethod
def get_op_info(cls, opdef):
"""Return the TF Python API function signature for a given `OpDef`.

Parameter
---------
opdef: str or `OpDef` object (meta or base)
"""
if isinstance(opdef, str):
opdef_name = opdef
opdef = op_def_registry.get(opdef_name)
else:
opdef_name = opdef.name

opdef_sig = cls.opdef_signatures.get(opdef_name, None)

if opdef_sig is None and opdef is not None:
opdef_func = getattr(tf.raw_ops, opdef.name, None)
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
cls.opdef_signatures[opdef.name] = cls.make_opdef_sig(opdef, opdef_func)

return opdef_sig


op_def_lib = MetaOpDefLibrary()

Expand All @@ -183,7 +181,6 @@ def _metatize_tf_object(obj):
def load_dispatcher():
"""Set/override dispatcher to default to TF objects."""

from tensorflow.python.framework.ops import EagerTensor
from tensorflow.python.ops.gen_linalg_ops import _SvdOutput

def _metatize_tf_svd(obj):
Expand All @@ -200,7 +197,7 @@ def _metatize_tf_eager(obj):
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
)

meta._metatize.add((EagerTensor,), _metatize_tf_eager)
meta._metatize.add((ops.EagerTensor,), _metatize_tf_eager)

meta._metatize.add((object,), _metatize_tf_object)
meta._metatize.add((HashableNDArray,), _metatize_tf_object)
Expand Down Expand Up @@ -599,12 +596,30 @@ def reify(self):
)

if not (op_inputs_unreified or op_attrs_unreified or isvar(self.name)):

apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs)
tf_out = operator._apply_func(**apply_arguments)
op_tf = tf_out.op

# TODO: Update NodeDef attrs?
#
# An operation with this name might already exist in the graph
#
try:
existing_op = ops.get_default_graph().get_operation_by_name(self.name)
except KeyError:
#
# There is no such `Operation`, so we attempt to create it
#
apply_arguments = operator.input_args(*op_inputs, name=self.name, **op_attrs)
tf_out = operator._apply_func(**apply_arguments)
op_tf = tf_out.op
else:
#
# An `Operation` with this name exists, let's make sure it's
# equivalent to this meta `Operation`
#
if self != mt(existing_op):
raise MetaReificationError(
f"An Operation with the name {self.name}"
" already exists in the graph and is not"
" equal to this meta object."
)
op_tf = existing_op

assert op_tf is not None
self._obj = op_tf
Expand Down Expand Up @@ -1149,4 +1164,5 @@ def __getattr__(self, obj):

mt = TFlowMetaAccessor()


load_dispatcher()
56 changes: 54 additions & 2 deletions tests/tensorflow/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
from unification import var, isvar

from symbolic_pymc.utils import HashableNDArray
from symbolic_pymc.meta import MetaSymbol, disable_auto_reification, enable_lvar_defaults
from symbolic_pymc.meta import (MetaSymbol, disable_auto_reification,
enable_lvar_defaults)
from symbolic_pymc.tensorflow.meta import (TFlowMetaTensor,
TFlowMetaTensorShape,
TFlowMetaOp,
TFlowMetaOpDef,
TFlowMetaNodeDef,
TFlowMetaOperator,
MetaOpDefLibrary,
MetaReificationError,
mt)

from tests.tensorflow import run_in_graph_mode
Expand Down Expand Up @@ -212,7 +214,7 @@ def test_meta_basic():


@run_in_graph_mode
def test_meta_Op():
def test_meta_operation():

t1_tf = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
t2_tf = tf.convert_to_tensor([[7, 8, 9], [10, 11, 12]])
Expand Down Expand Up @@ -673,3 +675,53 @@ def test_global_options():
with tf.Graph().as_default(), enable_lvar_defaults('names'):
a_mt = mt(1.0)
assert isvar(a_mt.name)


@run_in_graph_mode
def test_meta_const():
"""Make sure we can create a Const tensor by hand."""

with tf.Graph().as_default():
one_mt = mt.const(1, 'int32', 'Const')

with tf.Graph().as_default():
another_one_mt = mt(1)

assert one_mt == another_one_mt
assert isinstance(one_mt.reify(), tf.Tensor)
assert one_mt.reify().op.type == 'Const'


@run_in_graph_mode
def test_meta_existing_names():

with tf.Graph().as_default():
one_mt = mt(1)
assert one_mt.op.name == 'Const'

# Clear-out the associated base variable
orig_one_tf = one_mt._obj
one_mt.reset()
one_mt.op.reset()
assert one_mt.obj is None
assert one_mt.op.obj is None

# Attempt to reify to a base variable
one_tf = one_mt.reify()
assert one_tf.op.name == 'Const'
# Make sure it's the first base variable we created
assert orig_one_tf is one_tf

two_mt = mt(2)
two_mt.op.node_def.name = 'Const'

# TODO FIXME: We shouldn't have to do this manually after changing a
# dependency.
two_mt.reset()
two_mt.op.reset()
assert two_mt.obj is None
assert two_mt.op.obj is None
assert two_mt.op.name == 'Const'

with pytest.raises(MetaReificationError):
two_mt.reify()