Conversation

jianan-gu
Contributor

What does this PR do?

This PR supports the feature request #17137

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@jgong5

jgong5 commented May 9, 2022

cc @stas00

Contributor

@stas00 stas00 left a comment

Thank you for the detailed proposal and the implementation, @jianan-gu!

Other than the few comments/suggestions I left in the code, the next step is that we need tests that exercise this new feature.

The failing run_tests_torch_and_flax CI job has a flaky test - please ignore it.

@stas00 stas00 requested a review from sgugger May 9, 2022 21:05
@sgugger
Collaborator

sgugger commented May 10, 2022

Like #17153, I'm not entirely sure whether this is best put here or in Optimum, so I would like to hear back from @mfuntowicz and @LysandreJik before taking a deeper dive into the PR and reviewing :-)

@jianan-gu
Contributor Author

Hi @mfuntowicz, @LysandreJik, @sgugger, any suggestions about this PR? We are happy to discuss any concerns here.

Currently, BF16 AMP is already supported on GPU in Transformers, and with this PR we could extend it to the CPU side, both for inference and training. Furthermore, users could get performance benefits from IPEX with just one argument, "--use_ipex".

Thanks.

Collaborator

@sgugger sgugger left a comment

Thanks for working on this. Before merging this we would need to have some tests of the feature. I also have a few comments on the PR.

Comment on lines 161 to 162
if is_ipex_available() and "--use_ipex" in sys.argv:
import intel_extension_for_pytorch as ipex
Collaborator

We can't rely on sys.argv as the arguments may be passed to the Trainer class as TrainingArguments. This import should be done inside the Trainer as needed.
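
For illustration, a minimal sketch (not the PR's exact code) of reading the flag from TrainingArguments inside the Trainer and deferring the ipex import until it is actually requested; it assumes the use_ipex field this PR adds to TrainingArguments:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", use_ipex=True, bf16=True, no_cuda=True)

def maybe_import_ipex(args):
    # Import lazily so users without ipex installed are unaffected,
    # and so programmatic TrainingArguments work the same as the CLI flag.
    if args.use_ipex:
        import intel_extension_for_pytorch as ipex
        return ipex
    return None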

Contributor Author

Got it, we have moved this import inside ipex_optimize_model.

if version.parse(torch.__version__) >= version.parse("1.6"):
_is_torch_generator_available = True
_is_native_amp_available = True
from torch.cuda.amp import autocast
Collaborator

Why is this import removed?

Contributor Author

Because there are two kinds of autocast, one for CPU and one for CUDA. Since torch is already imported, we directly use torch.cpu.amp.autocast and torch.cuda.amp.autocast in the context manager.
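
A minimal sketch of that idea (not the exact Trainer code), selecting the autocast context manager by device type:

import torch

def autocast_ctx(device, dtype=torch.bfloat16):
    # Pick the CPU or CUDA autocast context manager based on the device type.
    if device.type == "cpu":
        return torch.cpu.amp.autocast(dtype=dtype)
    return torch.cuda.amp.autocast(dtype=dtype)

with autocast_ctx(torch.device("cpu")):
    x = torch.randn(8, 8) @ torch.randn(8, 8)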

Collaborator

Thanks for explaining!

else:
if args.bf16:
raise ValueError("Tried to use `bf16` but native amp is not available")
if args.device == torch.device("cuda"):
Collaborator

This check is completely unreliable: it will be False if args.device is set with torch.device(0) or when there are multiple GPUs. It should therefore be removed.
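
For illustration (not from the PR), a quick check of why the equality comparison is fragile and why checking the device type is more robust:

import torch

# An indexed device does not compare equal to the bare "cuda" device.
print(torch.device(0) == torch.device("cuda"))         # False
print(torch.device("cuda:1") == torch.device("cuda"))  # False
# Checking .type is reliable regardless of the index.
print(torch.device(0).type == "cuda")                  # True
print(torch.device("cpu").type == "cpu")               # True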

Contributor Author

Got it, we have changed it to check whether the device is CPU or something else (GPU, multiple GPUs, ...), thanks.

return model

def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
if not training:
Collaborator

Looks like the ipex import can go here.

Contributor Author

Have made changes accordingly, thanks.
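
A rough sketch of what the method can look like with the lazy import (the PR's actual implementation may differ in its details); ipex.optimize() returns the optimized model, or a (model, optimizer) pair when an optimizer is passed for training:

def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
    try:
        import intel_extension_for_pytorch as ipex
    except ImportError as e:
        raise ImportError("Using --use_ipex but intel_extension_for_pytorch is not installed.") from e

    if not training:
        model.eval()
        model = ipex.optimize(model, dtype=dtype, level="O1")
    else:
        model.train()
        model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
    return model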

Contributor

@stas00 stas00 left a comment

The doc looks great - thank you, @jianan-gu!

Added one small grammar fix

Would you like to add a benchmark showing the speed with and without ipex? Otherwise it's very hard to tell whether it's worth the effort of trying it.

As I mentioned earlier, you'd just run the trainer twice, w/ and w/o ipex, and make sure to add --skip_memory_metrics 0; it'll print the speed at which each run finished.

If it's not too much trouble, that is. I'm asking since you have probably done a lot of experiments already and seen the power of this extension, while none of us have.

@jianan-gu
Contributor Author

Hi, thanks for your suggestions and reply.
For the benchmark showing the speed with and without ipex, yes, we could do that and present the speedup results there.
Also, we have done a lot of experiments such as BERT training, and we will provide the performance gains, but it will take some time (perhaps 1-2 weeks) to go through the internal process of making them publicly available.
Besides, is it good enough if we demonstrate the speedup as a relative performance gain (like the table here)? Thanks

@stas00
Contributor

stas00 commented Jun 7, 2022

Besides, is it good enough if we demonstrate the speedup with relative performance gain (like the table here)?

I don't think so. Those performance numbers can't be reproduced by a user, so they are not very practical.

As I suggested there is no need for any major benchmarks, here is a simple eval one:

# no gpus / no ipex 
PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES= python \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --adam_eps 1e-06 --do_eval --evaluation_strategy=steps \
--label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step \
--logging_steps 500 --max_source_length 128 --max_target_length 128 \
--num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 16 \
--predict_with_generate --eval_steps 150 --sortish_sampler --source_lang en \
--target_lang ro --dataset_name wmt16 --dataset_config ro-en --source_prefix \
'translate English to Romanian: ' --val_max_target_length 128 --warmup_steps \
50 --max_eval_samples 500 --skip_memory_metrics 0 --bf16
[...]
***** eval metrics *****
  before_init_mem_cpu       =     1000MB
  eval_bleu                 =    24.1261
  eval_gen_len              =     39.554
  eval_loss                 =      3.786
  eval_mem_cpu_alloc_delta  =      279MB
  eval_mem_cpu_peaked_delta =      437MB
  eval_runtime              = 0:01:53.04
  eval_samples              =        500
  eval_samples_per_second   =      4.423
  eval_steps_per_second     =      0.283
  init_mem_cpu_alloc_delta  =        0MB
  init_mem_cpu_peaked_delta =        0MB
  
# no gpus / with ipex 
PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES= python \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --adam_eps 1e-06 --do_eval --evaluation_strategy=steps \
--label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step \
--logging_steps 500 --max_source_length 128 --max_target_length 128 \
--num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 16 \
--predict_with_generate --eval_steps 150 --sortish_sampler --source_lang en \
--target_lang ro --dataset_name wmt16 --dataset_config ro-en --source_prefix \
'translate English to Romanian: ' --val_max_target_length 128 --warmup_steps \
50 --max_eval_samples 500 --skip_memory_metrics 0 --bf16 --no_cuda --use_ipex
***** eval metrics *****
  before_init_mem_cpu       =     1010MB
  eval_bleu                 =    24.1261
  eval_gen_len              =     39.554
  eval_loss                 =     3.7863
  eval_mem_cpu_alloc_delta  =      518MB
  eval_mem_cpu_peaked_delta =      439MB
  eval_runtime              = 0:01:34.41
  eval_samples              =        500
  eval_samples_per_second   =      5.296
  eval_steps_per_second     =      0.339
  init_mem_cpu_alloc_delta  =        0MB
  init_mem_cpu_peaked_delta =        0MB

So 5.296 vs 4.423 eval_samples_per_second on my machine, i.e. the eval runtime dropped by ~16% - that's a good start.

model name      : 11th Gen Intel(R) Core(TM) i7-11700K @ 3.60GHz

It would probably be better to test with a slightly larger model, as it is likely to show a better speedup, but this already helps to see that it actually works!

This is good enough for a doc.

And down the road we can surely boost it with better, more in-depth benchmarks.

What do you think?

@stas00
Contributor

stas00 commented Jun 7, 2022

Low precision data type BFloat16 has been natively supported on the 3rd Generation Xeon® Scalable Processors (aka Cooper Lake) with AVX512 instruction set and will be supported on the next generation of Intel® Xeon® Scalable Processors with Intel® Advanced Matrix Extensions (Intel® AMX) instruction set with further boosted performance.

Given your description - is there a way to check that ipex is actually doing anything, other than being installed and importable?

i.e. currently it's just:

def is_ipex_available():
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None

But if the user's CPU is not Intel, or not the right kind of Intel, should it tell the user it's not supported and assert? Or is that already the case? I guess mine is the right CPU, so it doesn't fail.

@jgong5

jgong5 commented Jun 7, 2022

But if user's CPU is not Intel or not the right kind of Intel should it tell the user it's not supported and assert? or is it already the case? I guess mine is the right CPU so I it doesn't fail.

Hi @stas00 The optimization in IPEX is a superset of BF16. It optimizes for CPUs with AVX-512 or above, and it can also work functionally on CPUs with only AVX2. So I expect it also works functionally on AMD CPUs (even though we do not benchmark performance on AMD CPUs). For each op, IPEX pre-compiles multiple kernels for various CPU ISAs ahead of time and dispatches the kernel at runtime according to the underlying CPU's ISA capability. IPEX doesn't explicitly check CPU types internally. With this in mind, we are open to discussing whether, and which, additional checks should be added in this PR.
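
If an additional check were wanted, one option is a purely hypothetical helper like the one below (not part of this PR or of IPEX; Linux-only, since it reads /proc/cpuinfo) that warns when AVX-512 is not reported:

import warnings

def warn_if_no_avx512():
    # Hypothetical check: IPEX still works on AVX2-only CPUs, but the
    # AVX-512/BF16 fast paths will not be used.
    try:
        with open("/proc/cpuinfo") as f:
            flags = f.read()
    except OSError:
        return  # non-Linux systems: skip the check
    if "avx512f" not in flags:
        warnings.warn("This CPU does not report AVX-512 support; IPEX will run, but BF16 acceleration may be limited.")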

@jianan-gu
Contributor Author

Sure, got it. We will prepare benchmarks of example models to collect performance numbers (like the ones you showed above) and go through the internal review process. After that, we will update the doc with those performance numbers.
Thanks.

@stas00
Contributor

stas00 commented Jun 7, 2022

Hi @stas00 The optimization in IPEX is a superset of BF16. It optimizes for CPUs with AVX-512 or above, and it can also work functionally on CPUs with only AVX2. So I expect it also works functionally on AMD CPUs (even though we do not benchmark performance on AMD CPUs). For each op, IPEX pre-compiles multiple kernels for various CPU ISAs ahead of time and dispatches the kernel at runtime according to the underlying CPU's ISA capability. IPEX doesn't explicitly check CPU types internally. With this in mind, we are open to discussing whether, and which, additional checks should be added in this PR.

Perhaps mention it in the doc? Just briefly - i.e. that AMD CPUs and older Intel CPUs are also likely to see better performance under IPEX.

Sure, got it. We will prepare benchmarks of example models to collect performance numbers (like the ones you showed above) and go through the internal review process. After that, we will update the doc with those performance numbers.

I'd be perfectly happy to merge this sooner so as not to press you, if you kindly commit to adding at least some benchmarks later in a new PR - would that be a good arrangement for you guys?

@jgong5

jgong5 commented Jun 8, 2022

Hi @stas00

Perhaps mention it in the doc? Just briefly - i.e. that AMD CPUs and older Intel CPUs are also likely to see better performance under IPEX.

Thanks for the suggestions. We revised the doc accordingly. Could you please review?

I'd be perfectly happy to merge this sooner so as not to press you, if you kindly commit to adding at least some benchmarks later in a new PR - would that be a good arrangement for you guys?

Sure, sounds good to us. We will update the perf numbers in a separate PR as soon as they are ready.

Contributor

@stas00 stas00 left a comment

Looks good now, @jianan-gu

@sgugger - could you please have another look - just the docs - and then it's good to merge. Thank you!

Collaborator

@sgugger sgugger left a comment

LGTM, thanks a lot for iterating on this!

@sgugger sgugger merged commit 34097b3 into huggingface:main Jun 8, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…el Extension for PyTorch (huggingface#17138)

* init PR

* fix import ipex

* minor fix on bf16

* refine optimizer

* refine args notes

* refine code

* refine ipex optimize args

* refine half_precision_backend

* black format

* isort format

* isort format files

* flake8 format

* doc builder format

* refine codes

* remove jit and optim bits

* black preview format

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <[email protected]>

* refine code

* refine notes

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <[email protected]>

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <[email protected]>

* code refine

* add ipex ut

* add performance cpu doc

* link to the cpu doc from main perf doc

* install ipex into CI's docker

* Update perf_train_cpu.mdx

* Update docs/source/en/perf_train_cpu.mdx

Co-authored-by: Stas Bekman <[email protected]>

* Update perf_train_cpu.mdx

* Update perf_train_cpu.mdx

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jun 16, 2022
…el Extension for PyTorch (huggingface#17138)

@stas00
Contributor

stas00 commented Jun 30, 2022

@jianan-gu, fyi, ipex fails with pt-1.12

tests/trainer/test_trainer.py::TrainerIntegrationTest::test_evaluate_with_ipex
(line 6)  ImportError: /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/lib/libintel-ext-pt-cpu.so: undefined symbol: _ZNK3c1010TensorImpl5sizesEv
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_number_of_steps_in_training_with_ipex
(line 6)  ImportError: /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/lib/libintel-ext-pt-cpu.so: undefined symbol: _ZNK3c1010TensorImpl5sizesEv
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_predict_with_ipex
(line 6)  ImportError: /usr/local/lib/python3.8/dist-packages/intel_extension_for_pytorch/lib/libintel-ext-pt-cpu.so: undefined symbol: _ZNK3c1010TensorImpl5sizesEv

We are disabling the tests for now; please let us know when this is fixed so that we can re-enable them. Thank you!

@jianan-gu
Contributor Author

jianan-gu commented Jul 8, 2022

Hi @stas00
Thanks for the information.
As mentioned in this issue, the IPEX 1.12 release is available, and we have also opened PR #18072 to enhance the integration and handle this version-mismatch issue so that it does not break the Trainer.

Besides, regarding the performance data we discussed in this PR: we have mostly been working with the IPEX 1.12 release during the past weeks, and we would like to prepare the data based on the new release. It will take some time, thanks.
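
For illustration, a sketch of the kind of guard such a follow-up could add (the actual check in #18072 may differ): require matching torch/ipex major.minor versions before enabling IPEX, instead of failing later with an undefined-symbol ImportError:

from packaging import version
import torch

def check_ipex_matches_torch():
    import intel_extension_for_pytorch as ipex

    # Compare only the major.minor components, e.g. 1.12 vs 1.12.
    torch_mm = version.parse(torch.__version__).release[:2]
    ipex_mm = version.parse(ipex.__version__).release[:2]
    if torch_mm != ipex_mm:
        raise RuntimeError(
            f"intel_extension_for_pytorch {ipex.__version__} does not match the installed "
            f"torch {torch.__version__}; please install matching versions to use --use_ipex."
        )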

@Oxi84

Oxi84 commented Jan 18, 2023

I do not understand where this implementation is added.

import transformers
from transformers import T5ForConditionalGeneration,T5Tokenizer,T5TokenizerFast
model1a = T5ForConditionalGeneration.from_pretrained('t5-base',low_cpu_mem_usage=True)
tokenizer1 = T5TokenizerFast.from_pretrained('t5-base')

@stas00
Contributor

stas00 commented Jan 18, 2023

@Oxi84, please see: https://huggingface.co/docs/transformers/perf_train_cpu#mixed-precision-with-ipex

It's integrated into HF Trainer.
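
For example, a minimal sketch (assuming ipex is installed and you run on CPU) of how the snippet above would go through the IPEX path: the optimization only kicks in when the model is run via the Trainer with use_ipex enabled; plain .generate()/forward calls outside the Trainer are unaffected.

from transformers import T5ForConditionalGeneration, T5TokenizerFast, Trainer, TrainingArguments

model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# bf16=True + no_cuda=True enables BF16 AMP on CPU; use_ipex=True applies ipex.optimize().
args = TrainingArguments(output_dir="out", use_ipex=True, bf16=True, no_cuda=True)
trainer = Trainer(model=model, args=args, tokenizer=tokenizer)
# trainer.evaluate(eval_dataset=...) or trainer.predict(test_dataset=...) then runs
# the ipex-optimized model on CPU.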
