Skip to content

fix: predict with multiple outputs #1065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2023

Conversation

fryguy1013
Copy link
Contributor

It seems like this is how the code worked before a refactoring, but I'm not sure.

@Wanglongzhi2001
Copy link
Collaborator

No one seems to have refactored this file, but if you're having trouble with model.predict, would you mind providing a short code for us to reproduce?

@fryguy1013
Copy link
Contributor Author

fryguy1013 commented May 11, 2023

 var inputs = tf.keras.Input(shape: (6, 8, 3), name: "main_input");

var conv2d = tf.keras.layers.Conv2D(32, kernel_size: (3, 3),
    activation: tf.keras.activations.Linear).Apply(inputs);

var valueHead = tf.keras.layers.Dense(units: 1, use_bias: false, activation: tf.keras.activations.Linear).Apply(conv2d);
var policyHead = tf.keras.layers.Dense(units: 8*6*3, use_bias: false, activation: tf.keras.activations.Linear).Apply(conv2d);

var model = tf.keras.Model(
    inputs: inputs,
    outputs: new Tensors(valueHead, policyHead),
    "predictions"
);

var inputNN = np.zeros(6 * 8 * 3, TF_DataType.TF_FLOAT);
inputNN = np.reshape(inputNN, (1, 6, 8, 3));
var preds = model.predict(new Tensors(inputNN));

This was my test. I would expect preds to be an enumerable with two elements in it.

And about the refactoring, I just mean that commit 271dcef added a second method with the current behavior (just having tmp_batch_outputs[0] returned), and then it was refactored in 0ee50d3 to use the newly created method.

@Wanglongzhi2001
Copy link
Collaborator

Looks good to me.@Oceania2018

@AsakusaRinne AsakusaRinne changed the title Fix predict with multiple outputs fix: predict with multiple outputs May 11, 2023
@fryguy1013
Copy link
Contributor Author

I think someone with permissions needs to re-run the semantic check job after changing the name of the PR.

@fryguy1013 fryguy1013 force-pushed the fix-multiple-outputs branch from 3e84804 to 93cd2b6 Compare May 11, 2023 06:22
@AsakusaRinne
Copy link
Collaborator

All the unit tests and examples has passed. Thanks a lot for your contribution. :) The automatic nightly release is available by adding https://www.myget.org/F/scisharp/api/v3/index.json to your nuget source after merge (may take several minutes to wait for the action completes).

@AsakusaRinne AsakusaRinne merged commit cd64ea9 into SciSharp:master May 11, 2023
@fryguy1013 fryguy1013 deleted the fix-multiple-outputs branch May 11, 2023 18:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants