
Commit 44788dc

Merge pull request #4 from divya-kumari32/patch-1

Added overview and other additional details for Bamba

2 parents: f7ceb0c + 0115bf6

File tree

1 file changed: +25 -32 lines

docs/source/en/model_doc/bamba.md

Lines changed: 25 additions & 32 deletions
@@ -16,51 +16,44 @@ rendered properly in your Markdown viewer.

# Bamba

-## Overview

-TODO
+## Overview

-Tips:
+Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.

-```python
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+Check out all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).

-model_path = "..."
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-# drop device_map if running on CPU
-model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
-model.eval()
-
-# change input text as desired
-prompt = "Write a code to find the maximum value in a list of numbers."
-
-# tokenize the text
-input_tokens = tokenizer(prompt, return_tensors="pt")
-# generate output tokens
-output = model.generate(**input_tokens, max_new_tokens=100)
-# decode output tokens into text
-output = tokenizer.batch_decode(output)
-# loop over the batch to print, in this example the batch size is 1
-for i in output:
-    print(i)
-```
+## BambaConfig

-<!-- update this -->
-This model was contributed by [ani300]https://github.com/ani300) and [fabianlim]https://github.com/fabianlim) .
+| Model | Params | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
+|-------|--------|----------|-------------|-----------------|-----|----------|----------------|-----------------|
+| Bamba | 9B (9.78B) | 32 | 4096 | 32 | Yes | 8 | 4096 | True |

+<!---
+## Usage Tips

-## BambaConfig
+Tips:

-[[autodoc]] BambaConfig
+- The architecture is based on Mamba-2 models.

## BambaModel

[[autodoc]] BambaModel
    - forward
+-->

## BambaForCausalLM

-[[autodoc]] BambaForCausalLM
-    - forward
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
+tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")
+
+message = ["I am an LLM and my name is "]
+inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
+response = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
+print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
+```
+
+This HF implementation was contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
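The architecture table added under `## BambaConfig` can be cross-checked against the released checkpoint. Below is a minimal sketch, not part of the commit, that reads the configuration of the `ibm-fms/Bamba-9B` checkpoint named in the usage example; the attribute names assume the standard `transformers` config conventions and should be compared against the table.

```python
from transformers import AutoConfig

# Load only the configuration (no weights) for the checkpoint used in the docs.
config = AutoConfig.from_pretrained("ibm-fms/Bamba-9B")

# Attribute names assumed to follow the usual transformers conventions;
# compare the printed values against the table in the documentation.
print(config.num_hidden_layers)    # table says 32 layers
print(config.hidden_size)          # table says hidden dim. 4096
print(config.num_attention_heads)  # table says 32 attention heads
print(config.num_key_value_heads)  # table says 8 KV heads (GQA)
print(config.tie_word_embeddings)  # table says True
```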
