Commit 885b096

Author: Prashant Kumar

[WEB] Cache the compiled module.
-- Don't compile the module again and again.

1 parent: a886cba

2 files changed, 33 additions and 19 deletions


web/models/albert_maskfill.py

Lines changed: 19 additions & 13 deletions
@@ -6,6 +6,7 @@
 
 MAX_SEQUENCE_LENGTH = 512
 BATCH_SIZE = 1
+COMPILE_MODULE = None
 
 
 class AlbertModule(torch.nn.Module):
@@ -54,18 +55,23 @@ def top5_possibilities(text, inputs, token_logits):
 
 
 def albert_maskfill_inf(masked_text):
+    global COMPILE_MODULE
     inputs = preprocess_data(masked_text)
-    mlir_importer = SharkImporter(
-        AlbertModule(),
-        inputs,
-        frontend="torch",
-    )
-    minilm_mlir, func_name = mlir_importer.import_mlir(
-        is_dynamic=False, tracing_required=True
-    )
-    shark_module = SharkInference(
-        minilm_mlir, func_name, mlir_dialect="linalg"
-    )
-    shark_module.compile()
-    token_logits = torch.tensor(shark_module.forward(inputs))
+    if COMPILE_MODULE == None:
+        print("module compiled")
+        mlir_importer = SharkImporter(
+            AlbertModule(),
+            inputs,
+            frontend="torch",
+        )
+        minilm_mlir, func_name = mlir_importer.import_mlir(
+            is_dynamic=False, tracing_required=True
+        )
+        shark_module = SharkInference(
+            minilm_mlir, func_name, mlir_dialect="linalg", device="intel-gpu"
+        )
+        shark_module.compile()
+        COMPILE_MODULE = shark_module
+
+    token_logits = torch.tensor(COMPILE_MODULE.forward(inputs))
     return top5_possibilities(masked_text, inputs, token_logits)
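
For context, the caching introduced above can also be read as a small self-contained helper. This is a minimal sketch, not part of the commit: it assumes the SharkImporter / SharkInference calls and imports already present in web/models/albert_maskfill.py, the _compiled_albert helper name is hypothetical, and it uses the more idiomatic `is None` check in place of `== None`.

    COMPILE_MODULE = None  # module-level cache, as introduced by this commit


    def _compiled_albert(inputs):
        # Hypothetical helper: build and compile the module on the first call,
        # then return the cached SharkInference object on every later call.
        global COMPILE_MODULE
        if COMPILE_MODULE is None:
            mlir_importer = SharkImporter(
                AlbertModule(),
                inputs,
                frontend="torch",
            )
            minilm_mlir, func_name = mlir_importer.import_mlir(
                is_dynamic=False, tracing_required=True
            )
            shark_module = SharkInference(
                minilm_mlir, func_name, mlir_dialect="linalg", device="intel-gpu"
            )
            shark_module.compile()
            COMPILE_MODULE = shark_module
        return COMPILE_MODULE


    def albert_maskfill_inf(masked_text):
        inputs = preprocess_data(masked_text)
        token_logits = torch.tensor(_compiled_albert(inputs).forward(inputs))
        return top5_possibilities(masked_text, inputs, token_logits)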

web/models/resnet50.py

Lines changed: 14 additions & 6 deletions
@@ -7,6 +7,8 @@
 
 ################################## Preprocessing inputs and model ############
 
+COMPILE_MODULE = None
+
 
 def preprocess_image(img):
     image = Image.fromarray(img)
@@ -49,13 +51,19 @@ def top3_possibilities(res):
 def resnet_inf(numpy_img):
     img = preprocess_image(numpy_img)
     ## Can pass any img or input to the forward module.
-    mlir_model, func_name, inputs, golden_out = download_torch_model(
-        "resnet50"
-    )
+    global COMPILE_MODULE
+    if COMPILE_MODULE == None:
+        mlir_model, func_name, inputs, golden_out = download_torch_model(
+            "resnet50"
+        )
+
+        shark_module = SharkInference(
+            mlir_model, func_name, device="intel-gpu", mlir_dialect="linalg"
+        )
+        shark_module.compile()
+        COMPILE_MODULE = shark_module
 
-    shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
-    shark_module.compile()
-    result = shark_module.forward((img.detach().numpy(),))
+    result = COMPILE_MODULE.forward((img.detach().numpy(),))
 
     # print("The top 3 results obtained via shark_runner is:")
     return top3_possibilities(torch.from_numpy(result))
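
To show the effect of the cache at call time, here is a hypothetical usage sketch; the numpy import, the module import path, and the image shape/dtype are assumptions for illustration, not part of the commit.

    import numpy as np

    from web.models.resnet50 import resnet_inf  # import path assumed from the file tree

    # Any HWC uint8 image works, since preprocess_image calls Image.fromarray(img);
    # the 224x224x3 shape here is only an illustrative choice.
    img = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)

    top3_first = resnet_inf(img)   # first call: downloads, imports, and compiles the model
    top3_again = resnet_inf(img)   # later calls reuse COMPILE_MODULE, no recompilation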
