Functions #238

Open
@rnett

Functions

This relies on things described in #237. Being familiar with Python's function semantics will help, although I give an overview below.

ConcreteFunction

A ConcreteFunction, both here and in Python, is a wrapper for a native function handle: a set of "compiled" code that can be run, but whose signature can't be changed. Performance optimizations also apparently depend on that signature being well-known, i.e. argument shapes should be specified, even though placeholders don't require them.

Python

Functions in Python are defined with tf.function. Since Python is duck typed, the dtypes and shapes of inputs can change, so a function creates a new ConcreteFunction for each distinct realized input signature (this is called tracing). Type hints can be used to limit re-tracing.
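The per-signature caching described above can be modeled in a few lines of plain Java. This is a minimal sketch, not TensorFlow's implementation. `TracingCache` is a hypothetical name, a "signature" is any key with value equality, and the "concrete function" is whatever the supplied tracer builds.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical model of tf.function-style tracing: one concrete function is
// built (traced) per distinct input signature and cached for later calls.
final class TracingCache<S, F> {
  private final Map<S, F> concrete = new HashMap<>();
  private final Function<S, F> tracer; // builds a concrete function for a signature
  private int traceCount = 0;

  TracingCache(Function<S, F> tracer) {
    this.tracer = tracer;
  }

  // Returns the cached concrete function for this signature, tracing only on a miss.
  F get(S signature) {
    F fn = concrete.get(signature);
    if (fn == null) {
      fn = tracer.apply(signature);
      traceCount++;
      concrete.put(signature, fn);
    }
    return fn;
  }

  int traceCount() {
    return traceCount;
  }
}
```

Calling `get` twice with the same signature returns the same cached object and traces once. A new dtype or shape in the signature triggers a second trace.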

Supported arguments and what they expose to the signature are:

  • Tensors -> Dtype and shape
  • List of tensors -> number of items
  • Map with tensor values -> keys
  • Python objects -> identity
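Here is a hypothetical sketch of turning arguments into a signature key that follows the list above. `FakeTensor` stands in for a real tensor, and only its dtype and shape are used. `IdentityKey` gives identity semantics to other Java objects. None of these names come from TensorFlow.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

// Hypothetical stand-in for a tensor argument; only dtype and shape reach the signature.
final class FakeTensor {
  final String dtype;
  final long[] shape;

  FakeTensor(String dtype, long... shape) {
    this.dtype = dtype;
    this.shape = shape;
  }
}

// Wraps a Java object so that two keys are equal only if they hold the same instance.
final class IdentityKey {
  private final Object ref;

  IdentityKey(Object ref) {
    this.ref = ref;
  }

  @Override
  public boolean equals(Object o) {
    return o instanceof IdentityKey && ((IdentityKey) o).ref == ref;
  }

  @Override
  public int hashCode() {
    return System.identityHashCode(ref);
  }
}

final class Signatures {
  // Builds a cache key from the arguments, one element per argument.
  static List<Object> keyOf(Object... args) {
    List<Object> key = new ArrayList<>();
    for (Object arg : args) {
      if (arg instanceof FakeTensor) {
        FakeTensor t = (FakeTensor) arg;
        key.add("tensor:" + t.dtype + Arrays.toString(t.shape)); // dtype and shape
      } else if (arg instanceof List) {
        key.add("list:" + ((List<?>) arg).size()); // number of items
      } else if (arg instanceof Map) {
        TreeSet<String> names = new TreeSet<>();
        for (Object k : ((Map<?, ?>) arg).keySet()) {
          names.add(String.valueOf(k));
        }
        key.add("map:" + names); // key names, sorted so order doesn't matter
      } else {
        key.add(new IdentityKey(arg)); // any other Java object: identity
      }
    }
    return key;
  }
}
```

Two equal but distinct Java objects produce different keys. That is why passing them as arguments causes a new trace for each value.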

Note that if Python objects are used as arguments, a new ConcreteFunction will be made for each set of argument values, so this is discouraged.

Because of the way tracing works, Python side effects are only executed when the function is traced, which is unreliable. They always run on the first call, but later calls may reuse an existing trace and skip them.
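The side-effect behavior can be shown with a toy model. The Java "body" only runs while tracing, and later calls reuse what it recorded. This is a hypothetical sketch: `SideEffectDemo` and its string signatures are illustrative only, and the counter stands in for something like a `println` in the user's function.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleUnaryOperator;

// Hypothetical model: the user's Java body only runs while tracing. Calls with an
// already-seen signature reuse the traced computation, so body side effects are skipped.
final class SideEffectDemo {
  int sideEffectCount = 0; // stands in for a println or counter in the user's body
  private final Map<String, DoubleUnaryOperator> traces = new HashMap<>();

  // "Tracing": runs the Java body once and returns the computation it recorded.
  private DoubleUnaryOperator trace(String signature) {
    sideEffectCount++; // Java side effect: happens at trace time only
    return x -> x + 2.0; // recorded computation: runs on every call
  }

  double call(String signature, double x) {
    return traces.computeIfAbsent(signature, this::trace).applyAsDouble(x);
  }
}
```

Two calls with the same signature compute both results but bump the counter once. Only a new signature runs the side effect again.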

Python control flow is converted to TensorFlow control flow by AutoGraph (this is impossible for us).

Variable creation and initialization are handled as described in #237. The only thing we need to know here is that global variable creation is forbidden after the first trace.
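The "no global variables after the first trace" rule amounts to a guard keyed on the trace count. Below is a hypothetical sketch: `VariableCreationGuard` is not part of any existing API, and variables are reduced to names.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical guard: global variables may only be created during the first trace.
final class VariableCreationGuard {
  private int currentTrace = 0;
  private final List<String> globals = new ArrayList<>();

  // Called by the (hypothetical) function machinery before each trace.
  void beginTrace() {
    currentTrace++;
  }

  void createGlobalVariable(String name) {
    if (currentTrace != 1) {
      throw new IllegalStateException(
          "Cannot create variable '" + name + "' during trace " + currentTrace
              + "; global variables may only be created during the first trace");
    }
    globals.add(name);
  }

  List<String> globals() {
    return List.copyOf(globals);
  }
}
```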

Outputs can be a single tensor or a tuple of tensors.

Java

For our implementation, I'm envisioning something similar to ConcreteFunction's API, where you would have a method like:

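import java.util.Map;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

// Inputs, Function and MapFunction are the proposed API (they don't exist yet).
class Example {
  // Builds the function body: inputs are declared as it runs, and named outputs are returned.
  private static Map<String, Operand<?>> defineFunc(Ops tf, Inputs inputs) {
    final var x = inputs.input("x", TFloat32.class);
    // 2.0f keeps the constant TFloat32 to match x (2.0 would create a TFloat64 constant).
    final var y = tf.math.add(x, tf.constant(2.0f));
    return Map.of("y", y);
  }

  public static void main(String[] args) {
    MapFunction func = Function.define(Example::defineFunc);
  }
}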

The actual syntax is mostly irrelevant and will likely change, but a few things to note:

  • inputs are defined as the function is built. I don't like this very much, since it allows for shenanigans using non-tensor inputs (e.g. adding n inputs where n is itself an input, which would throw a runtime error), but I don't see a good way around it.
  • as such, we only need to return outputs

Note that we have no way of converting Java control flow to TensorFlow's the way Python does. This will be noted and highlighted in the docs. It's generally not an issue, though. Since we can't read tensor values while building a function, any control flow can only depend on Java arguments, and new values of those cause a re-trace anyway (a bit sub-optimal, but it works fine).
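The claim that Java branches become per-value traces can be illustrated with a hypothetical model. Each trace records a list of op names, which are purely illustrative strings. A plain Java `if` on a Java argument runs once per trace, so each argument value gets its own recorded graph.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model: Java control flow runs at trace time, so a branch on a Java
// argument is baked into the traced op list, and each new value gets its own trace.
final class JavaBranchDemo {
  final Map<Boolean, List<String>> traces = new HashMap<>();

  List<String> call(boolean training) {
    return traces.computeIfAbsent(training, t -> {
      List<String> ops = new ArrayList<>();
      ops.add("MatMul");
      if (t) { // plain Java if: evaluated once, while tracing
        ops.add("Dropout");
      }
      ops.add("Relu");
      return ops;
    });
  }
}
```

This only holds because the branch condition is a Java value. A branch on a tensor's runtime value would need TensorFlow control-flow ops, which is the conversion we can't do.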

I also plan a Kotlin compiler plugin that does this transform for annotated functions, like Python does, but I'm not sure officially supporting it is a good idea (the compiler plugin API is unstable and undocumented, although it will be released eventually).

Inputs and Outputs

I'm using an Inputs builder instead of the existing signature builder because it's necessary to allow for more complicated inputs (i.e. not just placeholders).

Like Python, I would want to allow as inputs:

  • tensors (ofc) - input.input(Class<TType> dtype). Note that the type can be a family type like TFloating.
  • lists of tensors - input.list, or input.list(Class<TType> dtype) for a list of the same dtype
  • maps with tensor values - input.map, input.map(Class<TType> dtype)
  • Java objects - input.javaInput<T> (typing is done by casting, since the argument map is untyped anyway)
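Here is a hypothetical sketch of how such a builder could declare each parameter and hand back the supplied value. `InputsSketch` and its `Kind` enum are illustrative names. The dtype arguments are omitted, and tensor/list/map values are plain `Object`s. It shows `javaInput` typed by an unchecked cast over the untyped argument map.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical Inputs builder: each call declares a parameter (name + kind)
// and returns the value supplied for it in this call's untyped argument map.
final class InputsSketch {
  enum Kind { TENSOR, LIST, MAP, JAVA }

  private final Map<String, Object> args; // untyped arguments for this call
  final Map<String, Kind> parameters = new LinkedHashMap<>(); // declared parameters, in order

  InputsSketch(Map<String, Object> args) {
    this.args = args;
  }

  private Object declare(String name, Kind kind) {
    if (parameters.containsKey(name)) {
      throw new IllegalArgumentException("Duplicate input: " + name);
    }
    if (!args.containsKey(name)) {
      throw new IllegalArgumentException("Missing argument: " + name);
    }
    parameters.put(name, kind);
    return args.get(name);
  }

  Object input(String name) { return declare(name, Kind.TENSOR); }

  Object list(String name) { return declare(name, Kind.LIST); }

  Object map(String name) { return declare(name, Kind.MAP); }

  // Java objects are typed by an unchecked cast, since the argument map is untyped anyway.
  @SuppressWarnings("unchecked")
  <T> T javaInput(String name) {
    return (T) declare(name, Kind.JAVA);
  }
}
```

The recorded `parameters` map is the kind of static parameter list the next paragraph relies on.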

This would work very similarly to how Python handles it: a signature is created for each set of argument values (with the same criteria as Python's), and ConcreteFunctions are cached based on it. The parameter list (i.e. names and types) would be static, and we would use it to get the argument signature without re-tracing the function. If a retrace resulted in a different parameter list, an error would be thrown.
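One way to enforce the static parameter list is to record it on the first trace and compare every later trace against it. Below is a minimal sketch, assuming parameters are identified by name strings. `ParameterListCheck` is hypothetical.

```java
import java.util.List;

// Hypothetical check: the parameter list recorded on the first trace is treated as fixed,
// and any later trace that declares a different list is rejected.
final class ParameterListCheck {
  private List<String> expected; // null until the first trace

  void onTrace(List<String> declared) {
    if (expected == null) {
      expected = List.copyOf(declared);
    } else if (!expected.equals(declared)) {
      throw new IllegalStateException(
          "Parameter list changed between traces: " + expected + " -> " + declared);
    }
  }
}
```

This is exactly where the "n inputs where n is itself an input" case fails. Retracing with a different n declares a different list, so the call throws instead of silently producing a new signature shape.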

For outputs, I want to support single tensors, a list, and a String -> Operand map (shown above). The values of the list and map would be limited to tensors. The resulting Function objects will be type-safe with respect to their outputs (e.g. the call method returns a List or an Operand). I'll also add a type-safe single-input, single-output version.
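The output-typed function objects could look roughly like the interfaces below. This is a hypothetical sketch. `O` stands in for an output type such as `Operand<?>`, arguments are an untyped name-to-value map, and the interface and adapter names are illustrative, not a settled API. The adapters show how the single-output and single-input/single-output variants could be derived from a map-output function.

```java
import java.util.List;
import java.util.Map;

// Hypothetical output-typed function interfaces; arguments stay an untyped map.
interface MapOutputFunction<O> {
  Map<String, O> call(Map<String, Object> args);
}

interface ListOutputFunction<O> {
  List<O> call(Map<String, Object> args);
}

interface SingleOutputFunction<O> {
  O call(Map<String, Object> args);
}

// Single input, single output: typed on both sides.
interface UnaryFunction<I, O> {
  O call(I input);
}

final class OutputAdapters {
  // Narrows a map-output function to the one named output.
  static <O> SingleOutputFunction<O> single(MapOutputFunction<O> f, String output) {
    return args -> f.call(args).get(output);
  }

  // Binds the only input and the only output by name, giving a fully typed function.
  static <I, O> UnaryFunction<I, O> unary(MapOutputFunction<O> f, String input, String output) {
    return x -> f.call(Map.<String, Object>of(input, x)).get(output);
  }
}
```

Keeping the map form as the base and deriving the narrower types from it means only one calling path has to talk to the native function.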

Variables and Captures

As described in #237, global variables are created in an attached eager initScope, and only during the first trace. This works since everything from the initScope is automatically captured when used. As also mentioned in #237, I would like a way to create variables once and automatically reuse them on later calls, but this requires unique ids across the whole execution environment, which is prohibitive (or at least I can't think of a better way).

Operands from other environments that are accessible from the function definition can be captured by the closure. Only operands from the initScope will be captured automatically on use. Otherwise, you can use input.constCapture(x) to bring x into scope if it's from a compatible environment, or input.capture(() -> x) (the preferred way), which will reflect any updates to x.
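The difference between the two capture styles shows up even without TensorFlow. A const capture reads the value once, at definition time. A lambda capture re-evaluates on every call. This is a hypothetical model: `CaptureDemo` is illustrative, and the captured `x` is a plain mutable Java field rather than an Operand.

```java
import java.util.function.DoubleSupplier;

// Hypothetical model of the two capture styles; "x" is a plain mutable Java field.
final class CaptureDemo {
  static double x = 1.0;

  // Like input.constCapture(x): the current value is baked in when the function is defined.
  static DoubleSupplier defineWithConstCapture() {
    final double baked = x; // read once, at definition (trace) time
    return () -> baked + 10.0; // every call reuses the baked-in value
  }

  // Like input.capture(() -> x): the lambda is kept and evaluated on every call.
  static DoubleSupplier defineWithLambdaCapture(DoubleSupplier captured) {
    return () -> captured.getAsDouble() + 10.0;
  }
}
```

After `x` is updated, only the lambda-captured function sees the new value, which is why it's the preferred form.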

Constant captures from eager sessions use tf.constant(x.asTensor()), while lambda captures and captures from graphs work by adding an input and providing the argument automatically each time the function is called.
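The "add an input and feed it on every call" mechanism can be sketched as a wrapper that merges the user's arguments with freshly evaluated capture values. This is hypothetical. `CapturingFunction`, the `capture_N` naming scheme, and the map-based body are illustrative stand-ins for the real function machinery.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical sketch: each lambda capture becomes an extra (hidden) input,
// and its supplier is evaluated to feed that input on every call.
final class CapturingFunction {
  private final Map<String, Supplier<?>> captures = new LinkedHashMap<>();
  private final Function<Map<String, Object>, Object> body; // traced function, all inputs by name

  CapturingFunction(Function<Map<String, Object>, Object> body) {
    this.body = body;
  }

  // Registers a capture and returns the name of the hidden input that carries it.
  String capture(Supplier<?> source) {
    String name = "capture_" + captures.size();
    captures.put(name, source);
    return name;
  }

  Object call(Map<String, Object> userArgs) {
    Map<String, Object> feed = new HashMap<>(userArgs);
    captures.forEach((name, source) -> feed.put(name, source.get())); // fresh value each call
    return body.apply(feed);
  }
}
```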

Operands from any eager environment can be captured, but if captures from a graph are used, they must all come from the same graph, and the function can only be called from that graph, since the captured values need to be accessible at call time. This is almost always the case anyway.
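The same-graph rule can be checked with two small guards: one at capture time and one at call time. This is a hypothetical sketch in which a plain `Object` stands in for a Graph, `null` stands for an eager source, and `GraphCaptureCheck` is an illustrative name.

```java
// Hypothetical check: eager captures are unrestricted, but all graph captures must share
// one graph, and the function may then only be called from that graph.
final class GraphCaptureCheck {
  private Object captureGraph; // null until the first graph capture

  // sourceGraph is null for an operand from an eager environment.
  void onCapture(Object sourceGraph) {
    if (sourceGraph == null) {
      return; // eager: always allowed
    }
    if (captureGraph == null) {
      captureGraph = sourceGraph;
    } else if (captureGraph != sourceGraph) {
      throw new IllegalArgumentException("All graph captures must come from the same graph");
    }
  }

  void checkCallFrom(Object callerGraph) {
    if (captureGraph != null && captureGraph != callerGraph) {
      throw new IllegalStateException("Function captures from a graph; call it from that graph");
    }
  }
}
```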

Note: this is largely an implementation detail, but functions can be inlined when called from another function (and possibly from a Graph, although that's a bit harder). I still need to work out exactly how that will work with captures and whatnot, but it should definitely be possible (Python does it).

Saving and Loading

Currently not supported. This is described a bit in #237, but essentially we would need to convert the initial values from an eager context to a graph, and the needed C APIs aren't exposed yet (if it's even possible). There are workarounds, like just saving the current value, that we could look into if this turns out to be necessary.
