Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class QuantizedTypeConverter : public TypeConverter {
static Type convertQuantizedType(QuantizedType quantizedType) {
return quantizedType.getStorageType();
}

static Type convertTensorType(TensorType tensorType) {
if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
if (auto quantizedType =
dyn_cast<QuantizedType>(tensorType.getElementType()))
return tensorType.clone(convertQuantizedType(quantizedType));
return tensorType;
}
Expand All @@ -50,7 +51,6 @@ class QuantizedTypeConverter : public TypeConverter {
}

public:

explicit QuantizedTypeConverter() {
addConversion([](Type type) { return type; });
addConversion(convertQuantizedType);
Expand All @@ -63,7 +63,8 @@ class QuantizedTypeConverter : public TypeConverter {
};

// Conversion pass
class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
class StripFuncQuantTypes
: public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {

// Return whether a type is considered legal when occurring in the header of
// a function or as an operand to a 'return' op.
Expand All @@ -74,11 +75,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
}

public:

void runOnOperation() override {

auto moduleOp = cast<ModuleOp>(getOperation());
auto* context = &getContext();
auto *context = &getContext();

QuantizedTypeConverter typeConverter;
ConversionTarget target(*context);
Expand Down Expand Up @@ -111,4 +111,3 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT

} // namespace quant
} // namespace mlir

Loading