Description
Functions
This relies on things described in #237, and being familiar with the Python function semantics will help, although I give an overview.
ConcreteFunction
A ConcreteFunction, both here and in Python, is a wrapper for a native function handle. This is a set of "compiled" code that can be run, but the signature can't be changed. Performance optimizations apparently depend on that signature being well-known too, i.e. argument shapes should be specified, even if they aren't required on placeholders.
Python
Functions in Python are done via tf.function. Since Python is duck typed, the dtypes and shapes of inputs can change, so a function will create a new ConcreteFunction for each distinct realized input signature (this is called tracing). Type hints can be used to limit re-tracing.
Supported arguments and what they expose to the signature are:
- Tensors -> Dtype and shape
- List of tensors -> number of items
- Map with tensor values -> keys
- Python objects -> identity
Of note is that if Python objects are used as arguments, a new ConcreteFunction will be made for each set of argument values, so this is discouraged.
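To make the signature rules above a bit more concrete, here is a rough sketch in plain Java (no TensorFlow types). `TensorArg`, `IdentityRef` and `SignatureKey` are made-up names for illustration, not part of any existing API, and map keys are assumed to be strings.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

/** Derives the part of a cache key that one argument contributes (illustrative only). */
final class SignatureKey {
    private SignatureKey() {}

    static Object of(Object arg) {
        if (arg instanceof TensorArg) {
            TensorArg t = (TensorArg) arg;
            // Tensors expose dtype and shape, never their values
            return "tensor(" + t.dtype + ", " + Arrays.toString(t.shape) + ")";
        }
        if (arg instanceof List) {
            // Lists expose only their number of items
            return "list(" + ((List<?>) arg).size() + ")";
        }
        if (arg instanceof Map) {
            // Maps expose their keys (assumed Comparable, e.g. String); sorted so order does not matter
            return "map(" + new TreeSet<Object>(((Map<?, ?>) arg).keySet()) + ")";
        }
        // Any other object is keyed by identity
        return new IdentityRef(arg);
    }

    public static void main(String[] args) {
        System.out.println(of(new TensorArg("float32", 2, 3)));
        System.out.println(of(List.of("a", "b")));
    }
}

/** Hypothetical stand-in for a tensor argument: only dtype and shape are kept. */
final class TensorArg {
    final String dtype;
    final long[] shape;

    TensorArg(String dtype, long... shape) {
        this.dtype = dtype;
        this.shape = shape;
    }
}

/** Wrapper whose equality is reference identity of the wrapped object. */
final class IdentityRef {
    private final Object ref;

    IdentityRef(Object ref) {
        this.ref = ref;
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof IdentityRef && ((IdentityRef) o).ref == ref;
    }

    @Override
    public int hashCode() {
        return System.identityHashCode(ref);
    }
}
```

Two tensors with the same dtype and shape share a key (and so a ConcreteFunction), while every distinct plain object gets its own key, which is why object arguments are discouraged.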
Because of the way tracing works, Python side effects are executed only while tracing, which makes them unreliable. They will always be executed on the first call, but later calls only run them again if the function happens to be re-traced.
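A minimal sketch of why side effects are unreliable, again in plain Java. `TracedFunction` and its string signature are hypothetical; the counter stands in for any side effect in the function body, and the lambda stands in for the compiled ConcreteFunction.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator;

/** Caches one "compiled" function per signature; the body only runs when tracing. */
final class TracedFunction {
    private final Map<String, IntUnaryOperator> cache = new HashMap<>();
    int traceCount = 0;

    /** Runs the function body once to build the concrete function for this signature. */
    private IntUnaryOperator trace(String signature) {
        traceCount++; // side effect: happens during tracing only
        return x -> x + 2; // stand-in for the compiled code
    }

    /** Looks up the concrete function for the signature, tracing only on a cache miss. */
    int call(String signature, int x) {
        return cache.computeIfAbsent(signature, this::trace).applyAsInt(x);
    }

    public static void main(String[] args) {
        TracedFunction f = new TracedFunction();
        f.call("float32[2]", 1);
        f.call("float32[2]", 5);
        System.out.println("traces: " + f.traceCount);
    }
}
```

Calling twice with the same signature runs the side effect once; only a new signature triggers it again.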
Python control flow is converted to TensorFlow control flow (this is impossible for us).
Variable creation and initialization is handled as described in #237; the only thing we need to know is that global variable creation is forbidden after the first trace.
Outputs can be a single tensor or a tuple of tensors.
Java
For our implementation, I'm envisioning something similar to ConcreteFunction's API, where you would have a method like:
    class Example {
        private static Map<String, Operand<?>> defineFunc(Ops tf, Inputs inputs) {
            final var x = inputs.input("x", TFloat32.class);
            // 2.0f keeps the constant a TFloat32 so it matches x
            final var y = tf.math.add(x, tf.constant(2.0f));
            return Map.of("y", y);
        }

        public static void main(String[] args) {
            MapFunction func = Function.define(Example::defineFunc);
        }
    }
The actual syntax is mostly irrelevant and will likely change, but a few things to note:
- inputs are defined as the function is built. I don't like this very much (it allows for shenanigans using non-tensor inputs, e.g. adding n inputs where n is itself an input, which would throw a runtime error), but I don't see a good way around it.
- as such, we only need to return outputs
Note that we have no way of converting Java control flow to TensorFlow's, as is done in Python. This will be noted and highlighted in the docs. It's generally not an issue though: since we can't resolve tensors in functions, any control flow will depend on Java args, which will cause re-tracing on new values anyway (it's a bit sub-optimal, but works fine).
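To illustrate the control flow point, here is a plain-Java sketch in which the "graph" is just a list of op descriptions. `UnrollingTracer` is a hypothetical name, and the op strings are placeholders rather than real TensorFlow ops. The Java loop runs at trace time, so it is unrolled into the graph, and each new value of the Java argument forces a new trace.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Traces a body containing a Java loop; the loop is unrolled, not converted to a graph loop. */
final class UnrollingTracer {
    private final Map<Integer, List<String>> cache = new HashMap<>();
    int traces = 0;

    /** Builds the op list for a given Java argument n. */
    private List<String> trace(int n) {
        traces++;
        List<String> ops = new ArrayList<>();
        ops.add("x = input");
        for (int i = 0; i < n; i++) {
            // Plain Java control flow: executed now, leaving n copies of the op behind
            ops.add("x = add(x, 1)");
        }
        return ops;
    }

    /** Returns the cached op list for n, tracing only when n has not been seen before. */
    List<String> concreteFor(int n) {
        return cache.computeIfAbsent(n, this::trace);
    }

    public static void main(String[] args) {
        UnrollingTracer tracer = new UnrollingTracer();
        System.out.println(tracer.concreteFor(3));
    }
}
```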
I also plan a Kotlin compiler plugin to do the transform if you annotate the function like Python does, but I'm not sure if officially supporting it is a good idea (the compiler API is unstable and undocumented, although it will be released eventually).
Inputs and Outputs
I'm using an Inputs builder instead of the existing signature builder because it's necessary to allow for more complicated inputs (i.e. not just placeholders).
Like python, I would want to allow as inputs:
- tensors (ofc) - input.input(Class<TType> dtype). Note that the type can be a family type like TFloating.
- List of tensors - input.list, or input.list(Class<TType> dtype) for a list of the same dtype
- Map of tensor values - input.map, or input.map(Class<TType> dtype)
- Java objects - input.javaInput<T> (typing is done by casting, since the argument map is untyped anyways)
This would work very similarly to how Python handles it, creating a signature for each set of argument values (with the same criteria as Python's) and caching ConcreteFunctions based on that. The parameter list (i.e. names and types) would be static, and we would use it to get the argument signature without re-tracing the function. If a retrace resulted in a different parameter list, an error would be thrown.
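The static parameter list check could look roughly like the sketch below. `ParameterListGuard` is a hypothetical helper, and the "name:type" strings are just a stand-in for however parameters end up being recorded.

```java
import java.util.ArrayList;
import java.util.List;

/** Remembers the parameter list from the first trace and rejects retraces that change it. */
final class ParameterListGuard {
    private List<String> expected; // null until the first trace has been recorded

    /** Records the parameters declared during a trace, e.g. "x:TFloat32". */
    void recordTrace(List<String> declared) {
        if (expected == null) {
            expected = new ArrayList<>(declared);
            return;
        }
        if (!expected.equals(declared)) {
            throw new IllegalStateException(
                "Retrace declared " + declared + " but the first trace declared " + expected);
        }
    }

    public static void main(String[] args) {
        ParameterListGuard guard = new ParameterListGuard();
        guard.recordTrace(List.of("x:TFloat32"));
        guard.recordTrace(List.of("x:TFloat32")); // same list, accepted
        System.out.println("parameter list is stable");
    }
}
```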
For outputs, I want to support single tensors, a list, and a String -> Operand map (shown above). The values of the list and map would be limited to tensors. The resulting Function objects will be type safe wrt their outputs (i.e. the call method returns List or Operand). I'll also add a type safe single input single output version.
Variables and Captures
As described in #237, global variables are created in an attached eager initScope, and limited to the first trace. This works since everything from the initScope is automatically captured when used. As described in #237, I would like a way to create variables once and remember them on further calls automatically, but this requires execution environment wide unique ids, which is prohibitive (or at least I can't think of a better way).
Operands from other environments that are accessible from the function definition can be captured by the closure. Only operands from the initScope will be captured automatically on use. Otherwise, you can use input.constCapture(x) to bring x into scope if it's from a compatible environment, or input.capture(() -> x) (the preferred way), which will reflect any updates to x.
Constant captures from eager sessions use tf.constant(x.asTensor), and lambda captures or captures from graphs work by adding an input, but providing the argument each time the function is called.
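The difference between the two capture styles can be sketched in plain Java. `CaptureDemo`, `Cell`, `withConstCapture` and `withLambdaCapture` are hypothetical names that only mimic the proposed constCapture and capture behaviour with integers: the constant capture reads the value once at definition time, while the lambda capture is re-read as a hidden argument on every call.

```java
import java.util.function.IntSupplier;
import java.util.function.IntUnaryOperator;

final class CaptureDemo {
    private CaptureDemo() {}

    /** Mutable stand-in for an operand whose value can change between calls. */
    static final class Cell {
        int value;

        Cell(int value) {
            this.value = value;
        }
    }

    /** Mimics a constant capture: the value is baked in when the function is defined. */
    static IntUnaryOperator withConstCapture(Cell x) {
        final int snapshot = x.value;
        return arg -> arg + snapshot;
    }

    /** Mimics a lambda capture: the supplier is evaluated on every call. */
    static IntUnaryOperator withLambdaCapture(IntSupplier x) {
        return arg -> arg + x.getAsInt();
    }

    public static void main(String[] args) {
        Cell x = new Cell(1);
        IntUnaryOperator constant = withConstCapture(x);
        IntUnaryOperator live = withLambdaCapture(() -> x.value);
        x.value = 10; // update after both functions were defined
        System.out.println(constant.applyAsInt(0) + " vs " + live.applyAsInt(0));
    }
}
```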
Operands from any eager environment can be captured, but if captures from a graph are used they all must come from the same graph, and the function can only be called from that graph, since the captured value needs to be accessible for calls. This is the case almost all the time anyways.
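The same-graph restriction might be enforced with something like the following sketch. `CaptureSet` is a hypothetical name, and graphs are represented by arbitrary objects compared by reference; a null caller graph stands for an eager caller.

```java
/** Tracks which graph (if any) a function's captures come from. */
final class CaptureSet {
    private Object graph; // the single graph that graph captures may come from, or null

    /** Registers a capture from a graph; all such captures must share one graph. */
    void addGraphCapture(Object sourceGraph) {
        if (graph != null && graph != sourceGraph) {
            throw new IllegalArgumentException("All graph captures must come from the same graph");
        }
        graph = sourceGraph;
    }

    /** With graph captures present, the function is only callable from that graph. */
    boolean callableFrom(Object callerGraph) {
        return graph == null || graph == callerGraph;
    }

    public static void main(String[] args) {
        Object g = new Object();
        CaptureSet captures = new CaptureSet();
        captures.addGraphCapture(g);
        System.out.println(captures.callableFrom(g));
    }
}
```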
Note: this is largely an implementation detail, but functions can be inlined when being called from another function (and possibly a Graph, although that's a bit harder). I still need to work out exactly how that will work with captures and whatnot, but it should definitely be possible (Python does it).
Saving and Loading
Currently not supported. This is described a bit in #237, but essentially we would need to convert the initial values from an eager context to a graph, and the needed C APIs aren't exposed yet (if it's even possible). There are workarounds, like just saving the current value, that we could look into if this is very necessary.