Skip to content

Commit 945a9c3

Browse files
committed
feat: add batch model download and apply
1 parent 3f73957 commit 945a9c3

File tree

6 files changed

+387
-21
lines changed

6 files changed

+387
-21
lines changed

api/api.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
4848
}))
4949
}
5050

51-
cm := make(ConfigMerger)
51+
cm := NewConfigMerger()
5252
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
5353
log.Error().Msgf("error loading config files: %s", err.Error())
5454
}
@@ -60,39 +60,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
6060
}
6161

6262
if debug {
63-
for k, v := range cm {
64-
log.Debug().Msgf("Model: %s (config: %+v)", k, v)
63+
for _, v := range cm.ListConfigs() {
64+
cfg, _ := cm.GetConfig(v)
65+
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
6566
}
6667
}
6768
// Default middleware config
6869
app.Use(recover.New())
6970
app.Use(cors.New())
7071

72+
// LocalAI API endpoints
73+
applier := newGalleryApplier(loader.ModelPath)
74+
applier.start(cm)
75+
app.Post("/localai/gallery/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
76+
app.Get("/localai/gallery/op/status/:uid", getOpStatus(applier))
77+
7178
// openAI compatible API endpoint
79+
80+
// chat
7281
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
7382
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
7483

84+
// edit
7585
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
7686
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
7787

88+
// completion
7889
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
7990
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
8091

92+
// embeddings
8193
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
8294
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
83-
84-
// /v1/engines/{engine_id}/embeddings
85-
8695
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
8796

97+
// audio
8898
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
8999

100+
// images
90101
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
91102

92103
if imageDir != "" {
93104
app.Static("/generated-images", imageDir)
94105
}
95106

107+
// models
96108
app.Get("/v1/models", listModels(loader, cm))
97109
app.Get("/models", listModels(loader, cm))
98110

api/config.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"os"
88
"path/filepath"
99
"strings"
10+
"sync"
1011

1112
model "github.com/go-skynet/LocalAI/pkg/model"
1213
"github.com/gofiber/fiber/v2"
@@ -43,8 +44,16 @@ type TemplateConfig struct {
4344
Edit string `yaml:"edit"`
4445
}
4546

46-
type ConfigMerger map[string]Config
47+
type ConfigMerger struct {
48+
configs map[string]Config
49+
sync.Mutex
50+
}
4751

52+
func NewConfigMerger() *ConfigMerger {
53+
return &ConfigMerger{
54+
configs: make(map[string]Config),
55+
}
56+
}
4857
func ReadConfigFile(file string) ([]*Config, error) {
4958
c := &[]*Config{}
5059
f, err := os.ReadFile(file)
@@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) {
7281
}
7382

7483
func (cm ConfigMerger) LoadConfigFile(file string) error {
84+
cm.Lock()
85+
defer cm.Unlock()
7586
c, err := ReadConfigFile(file)
7687
if err != nil {
7788
return fmt.Errorf("cannot load config file: %w", err)
7889
}
7990

8091
for _, cc := range c {
81-
cm[cc.Name] = *cc
92+
cm.configs[cc.Name] = *cc
8293
}
8394
return nil
8495
}
8596

8697
func (cm ConfigMerger) LoadConfig(file string) error {
98+
cm.Lock()
99+
defer cm.Unlock()
87100
c, err := ReadConfig(file)
88101
if err != nil {
89102
return fmt.Errorf("cannot read config file: %w", err)
90103
}
91104

92-
cm[c.Name] = *c
105+
cm.configs[c.Name] = *c
93106
return nil
94107
}
95108

109+
func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
110+
cm.Lock()
111+
defer cm.Unlock()
112+
v, exists := cm.configs[m]
113+
return v, exists
114+
}
115+
116+
func (cm ConfigMerger) ListConfigs() []string {
117+
cm.Lock()
118+
defer cm.Unlock()
119+
var res []string
120+
for k := range cm.configs {
121+
res = append(res, k)
122+
}
123+
return res
124+
}
125+
96126
func (cm ConfigMerger) LoadConfigs(path string) error {
127+
cm.Lock()
128+
defer cm.Unlock()
97129
files, err := ioutil.ReadDir(path)
98130
if err != nil {
99131
return err
@@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error {
106138
}
107139
c, err := ReadConfig(filepath.Join(path, file.Name()))
108140
if err == nil {
109-
cm[c.Name] = *c
141+
cm.configs[c.Name] = *c
110142
}
111143
}
112144

@@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
253285
return modelFile, input, nil
254286
}
255287

256-
func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
288+
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
257289
// Load a config file if present after the model name
258290
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
259291
if _, err := os.Stat(modelConfig); err == nil {
@@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader
263295
}
264296

265297
var config *Config
266-
cfg, exists := cm[modelFile]
298+
cfg, exists := cm.GetConfig(modelFile)
267299
if !exists {
268300
config = &Config{
269301
OpenAIRequest: defaultRequest(modelFile),

api/gallery.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package api
2+
3+
import (
4+
"fmt"
5+
"io/ioutil"
6+
"net/http"
7+
"sync"
8+
9+
"github.com/go-skynet/LocalAI/pkg/gallery"
10+
"github.com/gofiber/fiber/v2"
11+
"github.com/google/uuid"
12+
"gopkg.in/yaml.v3"
13+
)
14+
15+
type galleryOp struct {
16+
req ApplyGalleryModelRequest
17+
id string
18+
}
19+
20+
type galleryOpStatus struct {
21+
Error error `json:"error"`
22+
Processed bool `json:"processed"`
23+
Message string `json:"message"`
24+
}
25+
26+
type galleryApplier struct {
27+
modelPath string
28+
sync.Mutex
29+
C chan galleryOp
30+
statuses map[string]*galleryOpStatus
31+
}
32+
33+
func newGalleryApplier(modelPath string) *galleryApplier {
34+
return &galleryApplier{
35+
modelPath: modelPath,
36+
C: make(chan galleryOp),
37+
statuses: make(map[string]*galleryOpStatus),
38+
}
39+
}
40+
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) {
41+
g.Lock()
42+
defer g.Unlock()
43+
g.statuses[s] = op
44+
}
45+
46+
func (g *galleryApplier) getstatus(s string) *galleryOpStatus {
47+
g.Lock()
48+
defer g.Unlock()
49+
50+
return g.statuses[s]
51+
}
52+
53+
func (g *galleryApplier) start(cm *ConfigMerger) {
54+
go func() {
55+
for op := range g.C {
56+
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"})
57+
58+
updateError := func(e error) {
59+
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true})
60+
}
61+
// Send a GET request to the URL
62+
response, err := http.Get(op.req.URL)
63+
if err != nil {
64+
updateError(err)
65+
continue
66+
}
67+
defer response.Body.Close()
68+
69+
// Read the response body
70+
body, err := ioutil.ReadAll(response.Body)
71+
if err != nil {
72+
updateError(err)
73+
continue
74+
}
75+
76+
// Unmarshal YAML data into a Config struct
77+
var config gallery.Config
78+
err = yaml.Unmarshal(body, &config)
79+
if err != nil {
80+
updateError(fmt.Errorf("failed to unmarshal YAML: %v", err))
81+
continue
82+
}
83+
84+
if err := gallery.Apply(g.modelPath, &config); err != nil {
85+
updateError(err)
86+
continue
87+
}
88+
89+
// Reload models
90+
if err := cm.LoadConfigs(g.modelPath); err != nil {
91+
updateError(err)
92+
continue
93+
}
94+
95+
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"})
96+
}
97+
}()
98+
}
99+
100+
// endpoints
101+
102+
type ApplyGalleryModelRequest struct {
103+
URL string `json:"url"`
104+
}
105+
106+
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error {
107+
return func(c *fiber.Ctx) error {
108+
109+
status := g.getstatus(c.Params("uid"))
110+
if status == nil {
111+
return fmt.Errorf("could not find any status for ID")
112+
}
113+
114+
return c.JSON(status)
115+
}
116+
}
117+
118+
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error {
119+
return func(c *fiber.Ctx) error {
120+
input := new(ApplyGalleryModelRequest)
121+
// Get input data from the request body
122+
if err := c.BodyParser(input); err != nil {
123+
return err
124+
}
125+
126+
uuid, err := uuid.NewUUID()
127+
if err != nil {
128+
return err
129+
}
130+
g <- galleryOp{
131+
req: *input,
132+
id: uuid.String(),
133+
}
134+
return c.JSON(struct {
135+
ID string `json:"uid"`
136+
}{ID: uuid.String()})
137+
}
138+
}

api/openai.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ func defaultRequest(modelFile string) OpenAIRequest {
142142
}
143143

144144
// https://platform.openai.com/docs/api-reference/completions
145-
func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
145+
func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
146146
return func(c *fiber.Ctx) error {
147147

148148
model, input, err := readInput(c, loader, true)
@@ -199,7 +199,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
199199
}
200200

201201
// https://platform.openai.com/docs/api-reference/embeddings
202-
func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
202+
func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
203203
return func(c *fiber.Ctx) error {
204204
model, input, err := readInput(c, loader, true)
205205
if err != nil {
@@ -256,7 +256,7 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
256256
}
257257
}
258258

259-
func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
259+
func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
260260

261261
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
262262
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
@@ -378,7 +378,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
378378
}
379379
}
380380

381-
func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
381+
func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
382382
return func(c *fiber.Ctx) error {
383383
model, input, err := readInput(c, loader, true)
384384
if err != nil {
@@ -449,7 +449,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
449449
450450
*
451451
*/
452-
func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
452+
func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error {
453453
return func(c *fiber.Ctx) error {
454454
m, input, err := readInput(c, loader, false)
455455
if err != nil {
@@ -574,7 +574,7 @@ func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, image
574574
}
575575

576576
// https://platform.openai.com/docs/api-reference/audio/create
577-
func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
577+
func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
578578
return func(c *fiber.Ctx) error {
579579
m, input, err := readInput(c, loader, false)
580580
if err != nil {
@@ -641,7 +641,7 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
641641
}
642642
}
643643

644-
func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
644+
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
645645
return func(c *fiber.Ctx) error {
646646
models, err := loader.ListModels()
647647
if err != nil {
@@ -655,7 +655,7 @@ func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx)
655655
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
656656
}
657657

658-
for k := range cm {
658+
for _, k := range cm.ListConfigs() {
659659
if _, exists := mm[k]; !exists {
660660
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
661661
}

0 commit comments

Comments
 (0)