
API convention to take runtime type parameters in constructors of tensor-type-parameterized classes? #201

Open
@deansher

Description


Our API design deliberately uses lots of tensor-type-parameterized classes. Here's a typical example:

```java
public class BinaryCrossentropy<U extends TNumber, T extends TNumber>
    extends MeanMetricWrapper<U, T> implements LossMetric<T> {
  // ...
}
```

Of course, Java generic type parameters are erased at runtime, so an instance of this class has no Class objects for U and T unless they are supplied explicitly, for example as constructor arguments.
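A minimal illustration of the erasure point, using a hypothetical `Holder` class rather than any TensorFlow type: two instances with different type arguments share a single runtime class, so neither can recover its type argument by itself.

```java
// Hypothetical demo class; not part of the TensorFlow Java API.
class ErasureDemo {
    // Stand-in for a tensor-type-parameterized class.
    static class Holder<T> {
        // Writing `T.class` here would not compile: T is erased at runtime.
    }

    // Returns true: Holder<String> and Holder<Integer> have the same runtime
    // class, so the type argument is not recoverable from the instance.
    static boolean sameRuntimeClass() {
        Holder<String> a = new Holder<>();
        Holder<Integer> b = new Holder<>();
        return a.getClass() == b.getClass();
    }

    public static void main(String[] args) {
        System.out.println(sameRuntimeClass()); // prints "true"
    }
}
```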

In a few cases, our initial implementations of such classes already needed these runtime types, so we required Class parameters in their constructors. Here's an example:

```java
public class MeanMetricWrapper<U extends TNumber, T extends TNumber> extends Mean<U, T> {
  // ...
  protected MeanMetricWrapper(Ops tf, String name, long seed, Class<T> type) {
    super(tf, name, seed, type);
  }
  // ...
}
```

But suppose we didn't need a stored runtime class when we first implemented this class, and later made changes that required one (or needed the runtime type for U as well)? To evolve the constructor API, we would have to choose among unpleasant alternatives:

  • Add a new constructor that takes the runtime class, and make it a runtime error to use the new functionality if you didn't use that constructor.
  • Punt on adding the new functionality to the existing class hierarchy, perhaps instead creating a MeanMetricWrapperV2 and then also forking subclasses as needed.
  • Or make a breaking change by adding a runtime type parameter to existing constructors (which in this case would break a ton of subclass APIs).
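The first alternative might look roughly like this. It is a sketch only: `EvolvedMetric` and its `describeResultType` feature are hypothetical, and `TNumber`/`TFloat64` are empty stand-ins for the real tensor types. The point is that a caller of the legacy constructor only learns about the missing type when the new feature is used.

```java
import java.util.Objects;

// Hypothetical stand-ins for the TensorFlow Java tensor types.
interface TNumber {}
final class TFloat64 implements TNumber {}

// Hypothetical class evolved by adding a second constructor.
class EvolvedMetric<T extends TNumber> {
    private final String name;
    private final Class<T> type; // null if the legacy constructor was used

    // Legacy constructor: existing callers and subclasses keep compiling.
    EvolvedMetric(String name) {
        this.name = name;
        this.type = null;
    }

    // New constructor that supplies the runtime type.
    EvolvedMetric(String name, Class<T> type) {
        this.name = name;
        this.type = Objects.requireNonNull(type, "type");
    }

    // Hypothetical new functionality that needs the runtime type.
    // Callers of the legacy constructor hit this failure only at runtime.
    String describeResultType() {
        if (type == null) {
            throw new IllegalStateException(
                name + ": construct with a Class argument to use this feature");
        }
        return name + " produces " + type.getSimpleName();
    }
}
```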

Should we make it a consistent pattern that classes parameterized by tensor types take the corresponding runtime types in their constructors?
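A minimal sketch of what that convention could look like, assuming one Class argument per tensor type parameter. `BaseMetric` and `ExampleLoss` are hypothetical names, and `TNumber`, `TFloat32`, `TFloat64` are empty stand-ins for the real types. Because every subclass already forwards both Class objects, a later base-class feature that needs either runtime type would not require a constructor change anywhere in the hierarchy.

```java
import java.util.Objects;

// Hypothetical stand-ins for the TensorFlow Java tensor types.
interface TNumber {}
final class TFloat32 implements TNumber {}
final class TFloat64 implements TNumber {}

// Hypothetical base class following the proposed convention: it receives a
// Class for each type parameter even if it does not use them yet.
abstract class BaseMetric<U extends TNumber, T extends TNumber> {
    protected final Class<U> inputType;
    protected final Class<T> resultType;

    protected BaseMetric(Class<U> inputType, Class<T> resultType) {
        this.inputType = Objects.requireNonNull(inputType, "inputType");
        this.resultType = Objects.requireNonNull(resultType, "resultType");
    }
}

// Hypothetical subclass: it simply forwards both Class objects upward.
class ExampleLoss<U extends TNumber, T extends TNumber> extends BaseMetric<U, T> {
    ExampleLoss(Class<U> inputType, Class<T> resultType) {
        super(inputType, resultType);
    }

    // Uses both runtime types, which erasure alone would not provide.
    String signature() {
        return inputType.getSimpleName() + " -> " + resultType.getSimpleName();
    }
}
```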

In Java, this will make constructor invocation less convenient. Here's an example from existing code:

```java
BinaryCrossentropy<TFloat32, TFloat64> instance =
    new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class);
```

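Note the TFloat64.class at the end of the call. In Java 8, this is pure added burden: it is redundant with the second type argument, TFloat64, in the declared type, yet both are required. With local variable type inference (var, added in Java 10), the declared type arguments can be dropped, because the diamond infers each type parameter from its Class argument; the only remaining burden is the .class expressions. (A type parameter with no corresponding Class argument, like U in this call, would instead be inferred as its bound, TNumber.) In Kotlin, this could be smoothed over entirely with inline factory functions that use reified type parameters.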
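A small sketch of the call-site difference, assuming Java 10 or later for `var`. `PairMetric` is a hypothetical class that takes a Class for both type parameters, and the tensor types are empty stand-ins. The assignment to an explicitly typed variable compiles only because the diamond inferred `PairMetric<TFloat32, TFloat64>` from the two Class arguments.

```java
// Hypothetical stand-ins for the TensorFlow Java tensor types.
interface TNumber {}
final class TFloat32 implements TNumber {}
final class TFloat64 implements TNumber {}

// Hypothetical class that takes a Class for each tensor type parameter.
class PairMetric<U extends TNumber, T extends TNumber> {
    final Class<U> inputType;
    final Class<T> resultType;

    PairMetric(Class<U> inputType, Class<T> resultType) {
        this.inputType = inputType;
        this.resultType = resultType;
    }
}

class InferenceDemo {
    public static void main(String[] args) {
        // Java 8 style: the declared type arguments repeat the Class arguments.
        PairMetric<TFloat32, TFloat64> explicit =
            new PairMetric<>(TFloat32.class, TFloat64.class);

        // Java 10+: var with the diamond infers PairMetric<TFloat32, TFloat64>
        // from the Class arguments alone.
        var inferred = new PairMetric<>(TFloat32.class, TFloat64.class);

        // Compiles only because the inferred type matches the explicit one.
        PairMetric<TFloat32, TFloat64> sameType = inferred;
        System.out.println(explicit.inputType.getSimpleName()); // prints "TFloat32"
        System.out.println(sameType.resultType.getSimpleName()); // prints "TFloat64"
    }
}
```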
