|
572 | 572 | "# can have significant performance benefits.\n", |
573 | 573 | "\n", |
574 | 574 | "class FruitTraceType(tf.types.experimental.TraceType):\n", |
575 | | - " def __init__(self, fruit_type):\n", |
576 | | - " self.fruit_type = fruit_type\n", |
| 575 | + " def __init__(self, fruit):\n", |
| 576 | + " self.fruit_type = type(fruit)\n", |
| 577 | + " self.fruit_value = fruit\n", |
577 | 578 | "\n", |
578 | 579 | " def is_subtype_of(self, other):\n", |
| 580 | + " # True if self subtypes `other` and `other`'s type matches FruitTraceType.\n", |
579 | 581 | " return (type(other) is FruitTraceType and\n", |
580 | 582 | " self.fruit_type is other.fruit_type)\n", |
581 | 583 | "\n", |
582 | 584 | " def most_specific_common_supertype(self, others):\n", |
| 585 | + " # `self` is the specific common supertype if all input types match it.\n", |
583 | 586 | " return self if all(self == other for other in others) else None\n", |
584 | 587 | "\n", |
| 588 | + " def placeholder_value(self, placeholder_context=None):\n", |
| 589 | + " # Use the fruit itself instead of the type for correct tracing.\n", |
| 590 | + " return self.fruit_value\n", |
| 591 | + "\n", |
585 | 592 | " def __eq__(self, other):\n", |
586 | 593 | " return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n", |
587 | 594 | " \n", |
|
591 | 598 | "class FruitWithTraceType:\n", |
592 | 599 | "\n", |
593 | 600 | " def __tf_tracing_type__(self, context):\n", |
594 | | - " return FruitTraceType(type(self))\n", |
| 601 | + " return FruitTraceType(self)\n", |
595 | 602 | "\n", |
596 | 603 | "class AppleWithTraceType(FruitWithTraceType):\n", |
597 | 604 | " flavor = tf.constant([1, 2])\n", |
|
1831 | 1838 | ], |
1832 | 1839 | "metadata": { |
1833 | 1840 | "colab": { |
1834 | | - "collapsed_sections": [], |
1835 | 1841 | "name": "function.ipynb", |
| 1842 | + "provenance": [], |
1836 | 1843 | "toc_visible": true |
1837 | 1844 | }, |
1838 | 1845 | "kernelspec": { |
|
0 commit comments