ggml : add IQ2 to test-backend-ops + refactoring #4990

Merged
ggerganov merged 6 commits into master on Jan 17, 2024

Conversation

ggerganov
Member

  • Add imatrix-based quantization tests to test-backend-ops (using a dummy imatrix of 1.0f weights)
  • Lazy quantization init API
  • Workaround for Apple linker bug (Fix MacOS Sonoma model quantization #4052)
  • Fix bug in CUDA mul_mat_vec_q when blocks_per_row % blocks_per_warp != 0 (out-of-bounds access)
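To illustrate the last item: the sketch below is a CPU-side model of a warp-strided loop, not the actual kernel code. It assumes each lane handles block index `i + lane_offset` while the loop bound only checks the base index `i`; the names `LoopStats`, `walk_unguarded`, `walk_guarded` and `lane_offset` are hypothetical and do not come from ggml. Under that assumption, out-of-bounds indices appear exactly when `blocks_per_row % blocks_per_warp != 0`.

```cpp
#include <cassert>

// Hypothetical bookkeeping for the sketch: how many block indices a loop
// touched inside the row and how many past its end.
struct LoopStats {
    int in_bounds;
    int out_of_bounds;
};

// Unguarded form: the loop condition tests only the base index `i`, so the
// index actually used (i + lane_offset) can exceed blocks_per_row - 1.
inline LoopStats walk_unguarded(int blocks_per_row, int blocks_per_warp) {
    assert(blocks_per_row > 0 && blocks_per_warp > 0);
    LoopStats s = {0, 0};
    for (int lane_offset = 0; lane_offset < blocks_per_warp; ++lane_offset) {
        for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
            const int ib = i + lane_offset;
            if (ib < blocks_per_row) { ++s.in_bounds; } else { ++s.out_of_bounds; }
        }
    }
    return s;
}

// Guarded form: start each lane at its own offset and test the index that
// is actually dereferenced, so no lane steps past the end of the row.
inline LoopStats walk_guarded(int blocks_per_row, int blocks_per_warp) {
    assert(blocks_per_row > 0 && blocks_per_warp > 0);
    LoopStats s = {0, 0};
    for (int lane_offset = 0; lane_offset < blocks_per_warp; ++lane_offset) {
        for (int ib = lane_offset; ib < blocks_per_row; ib += blocks_per_warp) {
            ++s.in_bounds;
        }
    }
    return s;
}
```

With 10 blocks per row and 4 blocks per warp, the unguarded walk touches indices 10 and 11, while the guarded walk still visits all 10 valid blocks exactly once.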

Comment on lines 62 to 68
// when the imatrix is optional, we want to test both quantization with and without imatrix
std::random_device rd;
std::default_random_engine generator(rd());
std::uniform_int_distribution<int> distribution(0, 1);
if (distribution(generator)) {
    im = nullptr;
}
Collaborator

@cebtenzzre cebtenzzre Jan 16, 2024

I know this code isn't critical, but creating a new random device and then a new random generator every time you want a random boolean is definitely an anti-pattern. You should at least make generator static. (In fact, creating random generators too often can exhaust the system's entropy source, e.g. /dev/random.)

And do you think we should use a fixed seed by default, or is this designed to be a non-deterministic test?
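A minimal sketch of both suggestions, assuming a single shared generator is acceptable for this test code. `test_rng` and `random_bool` are hypothetical helper names, and treating a seed of 0 as "seed from `std::random_device`" is a convention invented for this example only.

```cpp
#include <cstdint>
#include <random>

// Returns one process-wide engine, constructed on the first call only.
// seed == 0 (hypothetical convention) means "use std::random_device";
// any other value gives a reproducible sequence. Later calls ignore `seed`.
inline std::default_random_engine & test_rng(uint32_t seed = 0) {
    static std::default_random_engine generator(
        seed != 0 ? seed : std::random_device{}());
    return generator;
}

// Draws a random boolean from the given engine. The distribution object is
// cheap to construct; only the engine needs to be reused.
inline bool random_bool(std::default_random_engine & gen) {
    std::uniform_int_distribution<int> distribution(0, 1);
    return distribution(gen) != 0;
}
```

The check in the snippet above would then become `if (random_bool(test_rng())) { im = nullptr; }`. Initialization of the function-local static is thread-safe since C++11, but drawing from the shared engine from several threads at once is not, so this form only suits single-threaded call sites.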

Collaborator

btw init_thread is doing the same thing.

Member Author

For init_thread I'm not sure how to improve it - would static thread_local achieve anything? A plain static would be a data race AFAIK.

Either way - it's OK as it is for this part of the code
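For reference, a `static thread_local` version could look like the sketch below. `thread_rng` and `fill_uniform` are hypothetical names, not functions from test-backend-ops. Each thread gets its own engine, so there is no data race on the generator state. The catch is that the engine is still constructed once per thread, so this only saves work if the worker threads are reused; if new threads are spawned for each tensor, every one of them seeds a fresh engine anyway.

```cpp
#include <cstddef>
#include <random>

// One engine per thread, seeded lazily the first time that thread calls in.
inline std::default_random_engine & thread_rng() {
    static thread_local std::default_random_engine generator(std::random_device{}());
    return generator;
}

// Fills data[0, n) with uniform values using the calling thread's engine.
inline void fill_uniform(float * data, std::size_t n, float lo, float hi) {
    std::uniform_real_distribution<float> dist(lo, hi);
    std::default_random_engine & gen = thread_rng();
    for (std::size_t i = 0; i < n; i++) {
        data[i] = dist(gen);
    }
}
```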

Collaborator

@cebtenzzre cebtenzzre Jan 16, 2024

If it were thread-local you would have to initialize it separately, which would be ugly. If you wanted to improve it you could initialize the random generators at the top of init_tensor_uniform in advance like this (static so it happens only once):

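// Seed one engine per worker thread. Being static, this runs only on the first call,
// so the n_threads captured here is the value seen on that first call.
static std::vector<std::default_random_engine> generators = [n_threads]() {
    std::random_device rd; // single entropy source, used only for seeding
    std::vector<std::default_random_engine> vec;
    vec.reserve(n_threads);
    for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
    return vec;
}();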

And then pass them to init_thread by reference.
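Put together, the idea might look like the following self-contained sketch. The `init_thread` and `init_uniform` functions here are hypothetical stand-ins written for this example, and their signatures are assumptions, not the ones in test-backend-ops. The thread count is fixed on first use so that the static vector of engines always matches it.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <random>
#include <thread>
#include <vector>

// Hypothetical worker: fills data[start, end) using the engine it is given.
static void init_thread(std::default_random_engine & gen, float * data,
                        std::size_t start, std::size_t end, float lo, float hi) {
    std::uniform_real_distribution<float> dist(lo, hi);
    for (std::size_t i = start; i < end; i++) {
        data[i] = dist(gen);
    }
}

// Hypothetical driver: splits the buffer into chunks, one thread per chunk.
static void init_uniform(std::vector<float> & data, float lo = -1.0f, float hi = 1.0f) {
    // Fixed on the first call; hardware_concurrency() may return 0, hence the max.
    static const std::size_t n_threads = std::max(1u, std::thread::hardware_concurrency());

    // Engines are seeded once and reused by every later call.
    static std::vector<std::default_random_engine> generators = [] {
        std::random_device rd;
        std::vector<std::default_random_engine> vec;
        vec.reserve(n_threads);
        for (std::size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
        return vec;
    }();

    const std::size_t chunk = (data.size() + n_threads - 1) / n_threads;
    std::vector<std::thread> threads;
    threads.reserve(n_threads);
    for (std::size_t t = 0; t < n_threads; t++) {
        const std::size_t start = std::min(t * chunk, data.size());
        const std::size_t end   = std::min(start + chunk, data.size());
        // Each thread owns exactly one engine for the duration of the call.
        threads.emplace_back(init_thread, std::ref(generators[t]), data.data(), start, end, lo, hi);
    }
    for (std::thread & th : threads) {
        th.join();
    }
}
```

Because the engines are shared across calls, two concurrent calls to `init_uniform` would race on them; the sketch assumes the function is only ever called from one thread at a time.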

@ggerganov ggerganov added the sync Requires sync with the ggml repo after merging label Jan 17, 2024
@ggerganov ggerganov merged commit 3856668 into master Jan 17, 2024
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Feb 3, 2024
* ggml : add IQ2 to test-backend-ops + refactoring

ggml-ci

* cuda : update supports_op for IQ2

ggml-ci

* ci : enable LLAMA_CUBLAS=1 for CUDA nodes

ggml-ci

* cuda : fix out-of-bounds-access in `mul_mat_vec_q`

ggml-ci

* tests : avoid creating RNGs for each Q tensor

ggml-ci

* tests : avoid creating RNGs for each tensor

ggml-ci
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024