feat: support truncate long/double to int/float with option #407
Conversation
Signed-off-by: inocsin <[email protected]>
@peri044 please review this PR
You have to add this flag on the Python API side as well. Here are the locations.
core/conversion/var/Var.cpp
Outdated
    auto tensor = ptr_.ivalue->toTensor();
    if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kInt));
      LOG_WARNING("Truncate kLong to kInt for IValue");
Warning: Truncating weight (constant in the graph) from kLong to kInt to indicate that only constants are affected.
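For reference, a minimal sketch of how the suggested wording could slot into the branch quoted above (a fragment of the same function in Var.cpp, not the merged code; the exact message text is whatever gets settled on in review):

    // Sketch only: constant (weight) tensors of type kLong are downcast to kInt
    // before being frozen into a TensorRT constant layer.
    auto tensor = ptr_.ivalue->toTensor();
    if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kInt));
      LOG_WARNING("Truncating weight (constant in the graph) from kLong to kInt");
    }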
core/conversion/var/Var.cpp
Outdated
      LOG_WARNING("Truncate kLong to kInt for IValue");
    } else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kFloat));
      LOG_WARNING("Truncate kDouble to kFloat for IValue");
Same here w.r.t. the warning message.
core/conversion/var/Var.cpp
Outdated
    if (isIValue()) {
      auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
      auto tensor = ptr_.ivalue->toTensor();
We should also check these cases when the setting is not enabled, to give users a hint to turn on truncation.
Signed-off-by: inocsin <[email protected]>
core/conversion/var/Var.cpp
Outdated
    auto tensor = ptr_.ivalue->toTensor();
    if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
      TRTORCH_CHECK(0, "Unable to freeze tensor of type kLong/kDouble into constant layer, try to compile model with truncate_long_and_double ON");
Let's just call the types Int32, Int64, Float64, Float32 throughout the warning messages.
core/conversion/var/Var.cpp
Outdated
    auto tensor = ptr_.ivalue->toTensor();
    if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
      TRTORCH_CHECK(0, "Unable to freeze tensor of type kLong/kDouble into constant layer, try to compile model with truncate_long_and_double ON");
You can use TRTORCH_THROW_ERROR here
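For example, the check above could be written as something like the following (a sketch only, combining the earlier suggestion about Int64/Float64 naming; TRTORCH_THROW_ERROR is the repo's existing error macro mentioned in the comment above):

    // Sketch only: throw directly instead of TRTORCH_CHECK(0, ...).
    if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) &&
        !ctx->settings.truncate_long_and_double) {
      TRTORCH_THROW_ERROR(
          "Unable to freeze tensor of type Int64/Float64 into constant layer, "
          "try to compile the model with truncate_long_and_double enabled");
    }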
core/conversion/var/Var.cpp
Outdated
    } else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kInt));
      LOG_WARNING("Truncate kLong to kInt for IValue");
      LOG_WARNING("Warning: Truncating weight (constant in the graph) from kLong to kInt to indicate that only constants are affected.");
@inocsin You can remove "to indicate that only constants are affected." That part was just a note for you.
Signed-off-by: inocsin <[email protected]>
narendasan
left a comment
Mostly looks good to go; one comment left, and please resolve the open comments.
core/conversion/var/Var.cpp
Outdated
      LOG_WARNING("Warning: Truncating weight (constant in the graph) from Int64 to Int32.");
    } else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kFloat));
      LOG_WARNING("Warning: Truncating weight (constant in the graph) from Float64 to Float32.");
You don't need to say "Warning:"; the logging system adds that for you.
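So the end state would presumably look something like this (a sketch of the intended final messages, not the merged code):

    // Sketch only: LOG_WARNING already tags the line as a warning, so no "Warning:" prefix.
    if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kInt));
      LOG_WARNING("Truncating weight (constant in the graph) from Int64 to Int32");
    } else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
      weights = converters::Weights(ctx, tensor.toType(at::kFloat));
      LOG_WARNING("Truncating weight (constant in the graph) from Float64 to Float32");
    }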
Signed-off-by: inocsin <[email protected]>
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Type of change
Please delete options that are not relevant and/or add your own.
Support truncating long/double to int/float with an option; add CompileSpec support.
Please close the previous issue: support Long/Double type IValue #322
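A rough sketch of how the new option might be used from the C++ API once merged (the public truncate_long_and_double field name is assumed from ctx->settings in the diff, and the fixed-size CompileSpec constructor from the existing trtorch API; exact names may differ):

    // Sketch only: compile a TorchScript module that contains Long/Double constants,
    // letting the compiler truncate them to Int/Float instead of erroring out.
    #include "torch/script.h"
    #include "trtorch/trtorch.h"

    int main() {
      auto mod = torch::jit::load("model_with_long_weights.ts"); // hypothetical model file
      trtorch::CompileSpec spec({{1, 3, 224, 224}});             // one fixed-size input
      spec.truncate_long_and_double = true;                      // option added by this PR (assumed field name)
      auto trt_mod = trtorch::CompileGraph(mod, spec);
      trt_mod.save("model_trt.ts");
      return 0;
    }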
Checklist: