
Commit 65bac0b

Added python api
1 parent 549a904 commit 65bac0b

File tree

6 files changed: +435 -21 lines changed

python/sparknlp/annotator/seq2seq/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,3 +25,4 @@
 from sparknlp.annotator.seq2seq.nllb_transformer import *
 from sparknlp.annotator.seq2seq.cpm_transformer import *
 from sparknlp.annotator.seq2seq.qwen_transformer import *
+from sparknlp.annotator.seq2seq.starcoder_transformer import *
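
With this re-export in place, the new annotator is reachable from the package root as well as from its own module. A minimal sketch, assuming the package layout of this commit:

    # Both imports resolve to the same class once the re-export above is applied
    from sparknlp.annotator import StarCoderTransformer
    from sparknlp.annotator.seq2seq.starcoder_transformer import StarCoderTransformer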
python/sparknlp/annotator/seq2seq/starcoder_transformer.py

Lines changed: 335 additions & 0 deletions

@@ -0,0 +1,335 @@
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes for the StarCoderTransformer."""

from sparknlp.common import *


class StarCoderTransformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
    """StarCoder2: The Versatile Code Companion.

    StarCoder2 is a Transformer model designed specifically for code generation
    and understanding. Available in 3B, 7B, and 15B parameter sizes, it builds
    upon the advancements of its predecessors and is trained on a diverse
    dataset that includes multiple programming languages. This extensive
    training allows StarCoder2 to support a wide array of coding tasks, from
    code completion to generation.

    StarCoder2 was developed to assist developers in writing and understanding
    code more efficiently, making it a valuable tool for various software
    development and data science tasks.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> starcoder = StarCoderTransformer.pretrained() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("generation")

    The default model is ``"starcoder"``, if no name is provided. For available
    pretrained models please see the `Models Hub
    <https://sparknlp.org/models?q=starcoder2>`__.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT``           ``DOCUMENT``
    ====================== ======================

    Parameters
    ----------
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    minOutputLength
        Minimum length of the sequence to be generated, by default 0
    maxOutputLength
        Maximum length of output text, by default 20
    doSample
        Whether or not to use sampling; use greedy decoding otherwise, by default False
    temperature
        The value used to modulate the next token probabilities, by default 0.6
    topK
        The number of highest probability vocabulary tokens to keep for
        top-k-filtering, by default 50
    topP
        Top cumulative probability for vocabulary tokens, by default 0.9

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.
    repetitionPenalty
        The parameter for repetition penalty; 1.0 means no penalty, by default
        1.0
    noRepeatNgramSize
        If set to int > 0, all ngrams of that size can only occur once, by
        default 0
    ignoreTokenIds
        A list of token ids which are ignored in the decoder's output, by
        default []

    Notes
    -----
    This is a very computationally expensive module, especially on longer
    sequences. The use of an accelerator such as a GPU is recommended.

    References
    ----------
    - `StarCoder2: The Versatile Code Companion.
      <https://huggingface.co/blog/starcoder>`__
    - https://github.com/bigcode-project/starcoder

    **Paper Abstract:**

    *The BigCode project, an open-scientific collaboration focused on the responsible
    development of Large Language Models for Code (Code LLMs), introduces StarCoder2. In
    partnership with Software Heritage (SWH), we build The Stack v2 on top of the digital
    commons of their source code archive. Alongside the SWH repositories spanning 619
    programming languages, we carefully select other high-quality data sources, such as
    GitHub pull requests, Kaggle notebooks, and code documentation. This results in a
    training set that is 4× larger than the first StarCoder dataset. We train StarCoder2
    models with 3B, 7B, and 15B parameters on 3.3 to 4.3 trillion tokens and thoroughly
    evaluate them on a comprehensive set of Code LLM benchmarks.*

    *We find that our small model, StarCoder2-3B, outperforms other Code LLMs of similar
    size on most benchmarks, and also outperforms StarCoderBase-15B. Our large model,
    StarCoder2-15B, significantly outperforms other models of comparable size. In
    addition, it matches or outperforms CodeLlama-34B, a model more than twice its size.
    Although DeepSeekCoder-33B is the best-performing model at code completion for
    high-resource languages, we find that StarCoder2-15B outperforms it on math and code
    reasoning benchmarks, as well as several low-resource languages. We make the model
    weights available under an OpenRAIL license and ensure full transparency regarding
    the training data by releasing the Software Heritage persistent IDentifiers (SWHIDs)
    of the source code data.*

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("documents")
    >>> starcoder = StarCoderTransformer.pretrained("starcoder") \\
    ...     .setInputCols(["documents"]) \\
    ...     .setMaxOutputLength(50) \\
    ...     .setOutputCol("generation")
    >>> pipeline = Pipeline().setStages([documentAssembler, starcoder])
    >>> data = spark.createDataFrame([["def add(a, b):"]]).toDF("text")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("generation.result").show(truncate=False)
    +------------------------------+
    |result                        |
    +------------------------------+
    |[def add(a, b): return a + b] |
    +------------------------------+
    """

    name = "StarCoderTransformer"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.DOCUMENT

    configProtoBytes = Param(Params._dummy(), "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
                            typeConverter=TypeConverters.toInt)

    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
                            typeConverter=TypeConverters.toInt)

    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
                     typeConverter=TypeConverters.toBoolean)

    temperature = Param(Params._dummy(), "temperature", "The value used to modulate the next token probabilities",
                        typeConverter=TypeConverters.toFloat)

    topK = Param(Params._dummy(), "topK",
                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
                 typeConverter=TypeConverters.toInt)

    topP = Param(Params._dummy(), "topP",
                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
                 typeConverter=TypeConverters.toFloat)

    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
                              typeConverter=TypeConverters.toFloat)

    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
                              "If set to int > 0, all ngrams of that size can only occur once",
                              typeConverter=TypeConverters.toInt)

    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
                           "A list of token ids which are ignored in the decoder's output",
                           typeConverter=TypeConverters.toListInt)

    def setIgnoreTokenIds(self, value):
        """Sets a list of token ids which are ignored in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The words to be filtered out
        """
        return self._set(ignoreTokenIds=value)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling; use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to modulate the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to modulate the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all ngrams of that size can only occur once.

        Parameters
        ----------
        value : int
            N-gram size that can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.StarCoderTransformer", java_model=None):
        super(StarCoderTransformer, self).__init__(classname=classname, java_model=java_model)
        self._setDefault(minOutputLength=0, maxOutputLength=20, doSample=False, temperature=0.6, topK=50, topP=0.9,
                         repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1)

    @staticmethod
    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
        use_openvino : bool, optional
            Whether to use the OpenVINO engine, by default False

        Returns
        -------
        StarCoderTransformer
            The restored model
        """
        from sparknlp.internal import _StarCoderLoader
        jModel = _StarCoderLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return StarCoderTransformer(java_model=jModel)

    @staticmethod
    def pretrained(name="starcoder", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "starcoder"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLP's repositories otherwise.

        Returns
        -------
        StarCoderTransformer
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(StarCoderTransformer, name, lang, remote_loc)
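
The setters above map one-to-one onto the generation parameters, so decoding behavior can be tuned per pipeline. A minimal sketch of switching from greedy decoding to sampling; the model name matches the code's default, while the prompt and parameter values are illustrative only:

    import sparknlp
    from sparknlp.base import DocumentAssembler
    from sparknlp.annotator import StarCoderTransformer
    from pyspark.ml import Pipeline

    spark = sparknlp.start()

    documentAssembler = DocumentAssembler() \
        .setInputCol("text") \
        .setOutputCol("documents")

    # Enable sampling instead of greedy decoding; values are illustrative, not tuned
    starcoder = StarCoderTransformer.pretrained("starcoder") \
        .setInputCols(["documents"]) \
        .setOutputCol("generation") \
        .setDoSample(True) \
        .setTemperature(0.6) \
        .setTopK(50) \
        .setTopP(0.9) \
        .setMaxOutputLength(100) \
        .setRepetitionPenalty(1.1)

    pipeline = Pipeline().setStages([documentAssembler, starcoder])
    data = spark.createDataFrame([["def fibonacci(n):"]]).toDF("text")
    pipeline.fit(data).transform(data).select("generation.result").show(truncate=False)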

python/sparknlp/internal/__init__.py

Lines changed: 9 additions & 0 deletions

@@ -394,6 +394,15 @@ def __init__(self, path, jspark):
         )
 
 
+class _StarCoderLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark, use_openvino=False):
+        super(_StarCoderLoader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.seq2seq.StarCoderTransformer.loadSavedModel",
+            path,
+            jspark,
+            use_openvino,
+        )
+
 class _T5Loader(ExtendedJavaWrapper):
     def __init__(self, path, jspark):
         super(_T5Loader, self).__init__(
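
This loader is what `StarCoderTransformer.loadSavedModel` delegates to on the Python side. A minimal sketch of the load path, assuming a model has already been exported to a local folder (the paths are placeholders; the export step itself is outside this commit):

    import sparknlp
    from sparknlp.annotator import StarCoderTransformer

    spark = sparknlp.start()

    # "/tmp/starcoder_export" stands in for a locally exported model folder;
    # passing use_openvino=True would route loading through the OpenVINO engine.
    starcoder = StarCoderTransformer.loadSavedModel("/tmp/starcoder_export", spark) \
        .setInputCols(["documents"]) \
        .setOutputCol("generation")

    # Persist as a Spark NLP model so later runs can reload it without re-importing
    starcoder.write().overwrite().save("/tmp/starcoder_spark_nlp_model")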
