Merged

Commits (28)
- 4572bd6: add jit mode option and model wrap (jianan-gu, May 10, 2022)
- ff22924: Merge branch 'main' into Introduce_Jit (jianan-gu, May 11, 2022)
- b7d87d4: Update src/transformers/training_args.py (jianan-gu, May 20, 2022)
- 7be2f3f: Update src/transformers/training_args.py (jianan-gu, May 20, 2022)
- 74a806e: Merge branch 'huggingface:main' into Introduce_Jit (jianan-gu, May 20, 2022)
- 8ab4ad3: refine code (jianan-gu, May 20, 2022)
- fc197cc: Update src/transformers/trainer.py (jianan-gu, May 23, 2022)
- 1659d18: Update src/transformers/trainer.py (jianan-gu, May 23, 2022)
- e8bd011: add ut and refine code (jianan-gu, Jun 9, 2022)
- f516fdb: Merge branch 'main' into Introduce_Jit (jianan-gu, Jun 9, 2022)
- e1727ad: code refine (jianan-gu, Jun 9, 2022)
- 2d7b0a8: refine code (jianan-gu, Jun 9, 2022)
- 11ffcda: add inference doc (stas00, Jun 9, 2022)
- 8a07c6f: Update src/transformers/trainer.py (jianan-gu, Jun 10, 2022)
- c5ff4ae: Update src/transformers/trainer.py (jianan-gu, Jun 10, 2022)
- 5297c70: add cpu inference performance doc (jianan-gu, Jun 10, 2022)
- 1351b97: Update perf_infer_cpu.mdx (jianan-gu, Jun 10, 2022)
- 973375f: Update perf_infer_cpu.mdx (jianan-gu, Jun 13, 2022)
- 5f60fe7: Update performance.mdx (jianan-gu, Jun 13, 2022)
- aff4ac5: Update _toctree.yml (jianan-gu, Jun 13, 2022)
- e96d9cf: refine jit func naming (jianan-gu, Jun 13, 2022)
- ce8b230: Update _toctree.yml (jianan-gu, Jun 13, 2022)
- 04a36bd: Delete perf_infer_gpu_one.mdx (jianan-gu, Jun 13, 2022)
- 216b289: Update perf_infer_cpu.mdx (jianan-gu, Jun 13, 2022)
- 54cd559: Update docs/source/en/perf_infer_cpu.mdx (jianan-gu, Jun 14, 2022)
- ae377f5: add none check before jit (jianan-gu, Jun 14, 2022)
- 3d78cb5: Update docs/source/en/perf_infer_cpu.mdx (jianan-gu, Jun 14, 2022)
- ffdea68: Update docs/source/en/perf_infer_cpu.mdx (jianan-gu, Jun 14, 2022)
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -87,6 +87,8 @@
    title: Training on many GPUs
  - local: perf_train_cpu
    title: Training on CPU
  - local: perf_infer_cpu
    title: Inference on CPU
  - local: perf_hardware
    title: Custom hardware for training
  - local: testing
57 changes: 57 additions & 0 deletions docs/source/en/perf_infer_cpu.mdx
@@ -0,0 +1,57 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Efficient Inference on CPU

This guide focuses on running inference with large models efficiently on CPU.

## PyTorch JIT-mode (TorchScript)
TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.
Compared to the default eager mode, JIT mode in PyTorch normally yields better performance for model inference thanks to optimization techniques such as operator fusion.
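
For illustration, here is a minimal sketch of tracing a Transformers model by hand; this snippet is not part of the Trainer integration below, and it uses the same checkpoint as the command-line examples:

```python
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# torchscript=True makes the model return tuples, which torch.jit.trace expects.
tokenizer = AutoTokenizer.from_pretrained("csarron/bert-base-uncased-squad-v1")
model = AutoModelForQuestionAnswering.from_pretrained(
    "csarron/bert-base-uncased-squad-v1", torchscript=True
)
model.eval()

question = "What is TorchScript?"
context = "TorchScript is a way to serialize and optimize PyTorch models."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    # Trace with example inputs, then freeze to fold the weights into the graph.
    traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))
    traced = torch.jit.freeze(traced)
    start_logits, end_logits = traced(inputs["input_ids"], inputs["attention_mask"])
```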

For a gentle introduction to TorchScript, see the [Introduction to TorchScript tutorial](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#tracing-modules).

### IPEX Graph Optimization with JIT-mode
Intel® Extension for PyTorch (IPEX) provides further JIT-mode optimizations for Transformers models, and we highly recommend taking advantage of it. Frequently used operator patterns from Transformers models are already covered by IPEX JIT-mode fusions, such as Multi-head-attention, Concat Linear, Linear+Add, Linear+Gelu, and Add+LayerNorm. These fusions are enabled transparently, with no code change required. According to our analysis, ~70% of the most popular NLP tasks in question-answering, text-classification, and token-classification can benefit from these fusion patterns, in both Float32 precision and BFloat16 mixed precision.

See [IPEX Graph Optimization](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/features/graph_optimization.html) for more detailed information.
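
As a rough sketch (assuming IPEX is installed), combining `ipex.optimize` with tracing looks like the following; the fusion passes are applied when the frozen graph is first executed:

```python
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("csarron/bert-base-uncased-squad-v1")
model = AutoModelForQuestionAnswering.from_pretrained(
    "csarron/bert-base-uncased-squad-v1", torchscript=True
)
model.eval()

# Apply IPEX operator-level optimizations first; JIT-mode fusions such as
# Linear+Gelu and Add+LayerNorm then kick in during trace/freeze.
model = ipex.optimize(model, dtype=torch.float32)

inputs = tokenizer(
    "What is IPEX?", "IPEX adds Intel optimizations to PyTorch.", return_tensors="pt"
)
with torch.no_grad():
    traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]), strict=False)
    traced = torch.jit.freeze(traced)
    # The first calls trigger the graph-optimization passes.
    outputs = traced(inputs["input_ids"], inputs["attention_mask"])
```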

#### IPEX installation:

IPEX releases follow PyTorch; check the [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/) page for the approach matching your PyTorch version.
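
After installing, a quick sanity check (not part of this PR) is to confirm the IPEX build matches your PyTorch version:

```python
import torch
import intel_extension_for_pytorch as ipex

# IPEX versions track PyTorch releases, e.g. IPEX 1.11.x pairs with PyTorch 1.11.
print(f"torch: {torch.__version__}, ipex: {ipex.__version__}")
```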

### Usage of JIT-mode
To enable JIT mode in the Trainer, add the `--jit_mode_eval` flag to the Trainer command-line arguments.
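
The same flag can be set programmatically when constructing the Trainer. A minimal sketch, assuming `model` and `eval_dataset` are already defined by your own setup:

```python
from transformers import Trainer, TrainingArguments

# Hypothetical setup: `model` and `eval_dataset` are assumed to exist already.
args = TrainingArguments(
    output_dir="/tmp/",
    no_cuda=True,        # evaluate on CPU
    jit_mode_eval=True,  # trace the model with torch.jit before evaluation
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
metrics = trainer.evaluate()
```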

From the command line, take the [Transformers question-answering](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) example:

- Inference using jit mode on CPU:
<pre>python run_qa.py \
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/ \
--no_cuda \
<b>--jit_mode_eval </b></pre>

- Inference with IPEX using jit mode on CPU:
<pre>python run_qa.py \
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/ \
--no_cuda \
<b>--use_ipex \</b>
<b>--jit_mode_eval</b></pre>
2 changes: 1 addition & 1 deletion docs/source/en/performance.mdx
@@ -58,7 +58,7 @@ Efficient inference with large models in a production environment can be as chal

### CPU

_Coming soon_
[Go to CPU inference section](perf_infer_cpu)

### Single GPU

32 changes: 29 additions & 3 deletions src/transformers/trainer.py
@@ -1158,6 +1158,29 @@ def call_model_init(self, trial=None):

        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode because the dataloader is None.")
                return model
            # Build example inputs matching the shapes and dtypes of a real batch.
            jit_inputs = []
            example_batch = next(iter(dataloader))
            for key in example_batch:
                example_tensor = torch.ones_like(example_batch[key])
                jit_inputs.append(example_tensor)
            jit_inputs = tuple(jit_inputs)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
                    # strict=False allows tracing models that return dict outputs.
                    jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                    jit_model = torch.jit.freeze(jit_model)
                    # Warm-up call so the optimized graph is compiled before real evaluation.
                    jit_model(**example_batch)
                model = jit_model
            except (RuntimeError, TypeError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        if not is_ipex_available():
            raise ImportError(
@@ -1177,11 +1200,14 @@ def ipex_optimize_model(self, model, training=False, dtype=torch.float32):

        return model

    def _wrap_model(self, model, training=True):
    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if self.args.jit_mode_eval:
            model = self.torch_jit_model_eval(model, dataloader, training)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
@@ -2688,7 +2714,7 @@ def evaluation_loop(
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False)
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
@@ -3249,7 +3275,7 @@ def prediction_loop(
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False)
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
5 changes: 5 additions & 0 deletions src/transformers/training_args.py
@@ -236,6 +236,8 @@ class TrainingArguments:
            Random seed to be used with data samplers. If not set, random generators for data sampling will use the
            same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
            seed.
        jit_mode_eval (`bool`, *optional*, defaults to `False`):
            Whether or not to use PyTorch jit trace for inference.
        use_ipex (`bool`, *optional*, defaults to `False`):
            Use Intel extension for PyTorch when it is available. [IPEX
            installation](https://github.com/intel/intel-extension-for-pytorch).
@@ -610,6 +612,9 @@ class TrainingArguments:
    no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
    jit_mode_eval: bool = field(
        default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"}
    )
    use_ipex: bool = field(
        default=False,
        metadata={
75 changes: 75 additions & 0 deletions tests/trainer/test_trainer.py
@@ -844,6 +844,47 @@ def test_evaluate(self):
        expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

    def test_evaluate_with_jit(self):
        trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(
            a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True
        )
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

        # With logits preprocess
        trainer = get_regression_trainer(
            a=1.5,
            b=2.5,
            compute_metrics=AlmostAccuracy(),
            preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
            jit_mode_eval=True,
        )
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

    @require_torch_bf16
    @require_intel_extension_for_pytorch
    def test_evaluate_with_ipex(self):
@@ -930,6 +971,40 @@ def test_predict(self):
        self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
        self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

    def test_predict_with_jit(self):
        trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With more than one output of the model
        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertEqual(len(preds), 2)
        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))

        # With more than one output/label of the model
        trainer = get_regression_trainer(
            a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"], jit_mode_eval=True
        )
        outputs = trainer.predict(trainer.eval_dataset)
        preds = outputs.predictions
        labels = outputs.label_ids
        x = trainer.eval_dataset.x
        self.assertEqual(len(preds), 2)
        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
        self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
        self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

    @require_torch_bf16
    @require_intel_extension_for_pytorch
    def test_predict_with_ipex(self):