Skip to content

Commit d56ae62

Browse files
committed
Multiple finetune models (#179)
* add finetune info to models tab * show current completion model in finetune tab * modal warning if checkpoint does not match the selected model * Revert "Revert "hidden => False for starcoder models (#170)"" This reverts commit 52203da. * fix model name for lora setup * fix table styles
1 parent 551d4b0 commit d56ae62

File tree

8 files changed

+100
-22
lines changed

8 files changed

+100
-22
lines changed

known_models_db/refact_known_models/huggingface.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,35 +31,35 @@
3131
},
3232
"starcoder/1b/base": {
3333
"backend": "transformers",
34-
"model_path": "bigcode/starcoderbase-1b",
34+
"model_path": "smallcloudai/starcoderbase-1b",
3535
"diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
3636
"chat_scratchpad_class": None,
3737
"model_class_kwargs": {},
3838
"required_memory_mb": 6000,
3939
"T": 4096,
40-
"hidden": True,
40+
"hidden": False,
4141
"filter_caps": ["completion", "finetune"],
4242
},
4343
"starcoder/3b/base": {
4444
"backend": "transformers",
45-
"model_path": "bigcode/starcoderbase-3b",
45+
"model_path": "smallcloudai/starcoderbase-3b",
4646
"diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
4747
"chat_scratchpad_class": None,
4848
"model_class_kwargs": {},
4949
"required_memory_mb": 9000,
5050
"T": 4096,
51-
"hidden": True,
51+
"hidden": False,
5252
"filter_caps": ["completion", "finetune"],
5353
},
5454
"starcoder/7b/base": {
5555
"backend": "transformers",
56-
"model_path": "bigcode/starcoderbase-7b",
56+
"model_path": "smallcloudai/starcoderbase-7b",
5757
"diff_scratchpad_class": "refact_scratchpads:ScratchpadPSM",
5858
"chat_scratchpad_class": None,
5959
"model_class_kwargs": {},
6060
"required_memory_mb": 18000,
6161
"T": 2048,
62-
"hidden": True,
62+
"hidden": False,
6363
"filter_caps": ["completion", "finetune"],
6464
},
6565
"wizardcoder/15b": {

self_hosting_machinery/webgui/selfhost_model_assigner.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from self_hosting_machinery.webgui.selfhost_webutils import log
99
from known_models_db.refact_known_models import models_mini_db
1010
from known_models_db.refact_toolbox_db import modelcap_records
11+
from self_hosting_machinery.scripts.best_lora import find_best_lora
12+
from refact_data_pipeline.finetune.finetune_utils import get_active_loras
1113

1214
from typing import List, Dict, Set, Any
1315

@@ -205,12 +207,28 @@ def _capabilities(func_type: str) -> Set:
205207

206208
chat_caps = _capabilities("chat")
207209
toolbox_caps = _capabilities("toolbox")
210+
active_loras = get_active_loras(self.models_db)
208211
for k, rec in self.models_db.items():
209212
if rec.get("hidden", False):
210213
continue
214+
finetune_info = None
215+
if k in active_loras:
216+
lora_mode = active_loras[k]["lora_mode"]
217+
latest_best_lora_info = find_best_lora(k)
218+
if lora_mode == "latest-best" and latest_best_lora_info["latest_run_id"]:
219+
finetune_info = {
220+
"run": latest_best_lora_info["latest_run_id"],
221+
"checkpoint": latest_best_lora_info["best_checkpoint_id"],
222+
}
223+
elif lora_mode == "specific" and active_loras[k].get("specific_lora_run_id", ""):
224+
finetune_info = {
225+
"run": active_loras[k]["specific_lora_run_id"],
226+
"checkpoint": active_loras[k]["specific_checkpoint"],
227+
}
211228
info.append({
212229
"name": k,
213230
"backend": rec["backend"],
231+
"finetune_info": finetune_info,
214232
"has_completion": bool("completion" in rec["filter_caps"]),
215233
"has_finetune": bool("finetune" in rec["filter_caps"]),
216234
"has_toolbox": bool(toolbox_caps.intersection(rec["filter_caps"])),

self_hosting_machinery/webgui/static/tab-finetune.html

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
<div class="col-5">
55
<div class="pane use-model-pane">
66
<h3>Use Finetuned Model</h3>
7+
<div class="lora-model"><div id="lora-switch-model">Model:</div></div>
78
<div class="lora-group">
89
<div class="btn-group" role="group" aria-label="basic radio toggle button group">
910
<input type="radio" class="lora-switch btn-check" name="finetune_lora" value="off" id="loraradio1"
@@ -387,4 +388,23 @@ <h5>Limit training time</h5>
387388
</div>
388389
</div>
389390

391+
<div class="modal fade" id="finetune-tab-model-warning-modal" tabindex="-1" aria-labelledby="finetune-tab-invalid-model-modal" aria-hidden="true">
392+
<div class="modal-dialog modal-lg modal-dialog-centered">
393+
<div class="modal-content">
394+
<div class="modal-header">
395+
<h5 class="modal-title" id="urlModalLabel">Warning</h5>
396+
<button type="button" class="btn-close" data-bs-dismiss="modal" aria-label="Close"></button>
397+
</div>
398+
<div class="modal-body">
399+
<div class="row">
400+
<div class="mb-3" id="model-warning-message"></div>
401+
</div>
402+
</div>
403+
<div class="modal-footer">
404+
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Ok</button>
405+
</div>
406+
</div>
407+
</div>
408+
</div>
409+
390410
</div>

self_hosting_machinery/webgui/static/tab-finetune.js

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ function tab_finetune_config_and_runs() {
5454
return response.json();
5555
})
5656
.then(function (data) {
57-
console.log('tab-finetune-config-and-runs',data);
5857
finetune_configs_and_runs = data;
5958
render_runs();
6059
render_model_select();
@@ -274,7 +273,12 @@ const find_checkpoints_by_run = (run_id) => {
274273
};
275274

276275
function render_lora_switch() {
277-
let mode = finetune_configs_and_runs.active[finetune_configs_and_runs.config.model_name] ? finetune_configs_and_runs.active[finetune_configs_and_runs.config.model_name].lora_mode : "latest-best";
276+
const model_name = finetune_configs_and_runs.completion_model;
277+
let lora_switch_model = document.querySelector('#lora-switch-model');
278+
lora_switch_model.innerHTML = `
279+
<b>Model:</b> ${model_name}
280+
`;
281+
let mode = finetune_configs_and_runs.active[model_name] ? finetune_configs_and_runs.active[model_name].lora_mode : "latest-best";
278282
loras_switch_no_reaction = true; // avoid infinite loop when setting .checked
279283
if (mode === 'off') {
280284
loras_switch_off.checked = true;
@@ -291,8 +295,8 @@ function render_lora_switch() {
291295
lora_switch_checkpoint.style.display = 'block';
292296
lora_switch_run_id.style.opacity = 1;
293297
lora_switch_checkpoint.style.opacity = 1;
294-
lora_switch_run_id.innerHTML = `<b>Run:</b> ${finetune_configs_and_runs.active[finetune_configs_and_runs.config.model_name].specific_lora_run_id}`;
295-
lora_switch_checkpoint.innerHTML = `<b>Checkpoint:</b> ${finetune_configs_and_runs.active[finetune_configs_and_runs.config.model_name].specific_checkpoint}`;
298+
lora_switch_run_id.innerHTML = `<b>Run:</b> ${finetune_configs_and_runs.active[model_name].specific_lora_run_id}`;
299+
lora_switch_checkpoint.innerHTML = `<b>Checkpoint:</b> ${finetune_configs_and_runs.active[model_name].specific_checkpoint}`;
296300
} else if (mode == 'latest-best') {
297301
lora_switch_run_id.style.display = 'block';
298302
lora_switch_checkpoint.style.display = 'block';
@@ -303,8 +307,8 @@ function render_lora_switch() {
303307
} else {
304308
lora_switch_run_id.style.display = 'none';
305309
lora_switch_checkpoint.style.display = 'none';
306-
lora_switch_run_id.innerHTML = `<b>Run:</b> ${finetune_configs_and_runs.active[finetune_configs_and_runs.config.model_name].specific_lora_run_id}`;
307-
lora_switch_checkpoint.innerHTML = `<b>Checkpoint:</b> ${finetune_configs_and_runs.active[finetune_configs_and_runs.config.model_name].specific_checkpoint}`;
310+
lora_switch_run_id.innerHTML = `<b>Run:</b> ${finetune_configs_and_runs.active[model_name].specific_lora_run_id}`;
311+
lora_switch_checkpoint.innerHTML = `<b>Checkpoint:</b> ${finetune_configs_and_runs.active[model_name].specific_checkpoint}`;
308312
}
309313
}
310314

@@ -346,7 +350,20 @@ function render_checkpoints(data = []) {
346350
}
347351
row.classList.add('table-success');
348352
}
349-
finetune_switch_activate("specific", selected_lora, cell.dataset.checkpoint);
353+
const finetune_run = finetune_configs_and_runs.finetune_runs.find((run) => run.run_id === selected_lora);
354+
if (finetune_run && finetune_run.model_name !== finetune_configs_and_runs.completion_model) {
355+
let modal = document.getElementById('finetune-tab-model-warning-modal');
356+
let modal_instance = bootstrap.Modal.getOrCreateInstance(modal);
357+
document.querySelector('#finetune-tab-model-warning-modal #model-warning-message').innerHTML = `
358+
<label>
359+
Checkpoint you're about to activate trained for <b>${finetune_run.model_name}</b> model.
360+
Use another checkpoint for <b>${finetune_configs_and_runs.completion_model}</b> model instead.
361+
</label>
362+
`;
363+
modal_instance.show();
364+
} else {
365+
finetune_switch_activate("specific", selected_lora, cell.dataset.checkpoint);
366+
}
350367
});
351368
});
352369
}
@@ -361,7 +378,7 @@ function animate_use_model() {
361378

362379
function finetune_switch_activate(lora_mode, run_id, checkpoint) {
363380
animate_use_model();
364-
const model_name = document.querySelector('#finetune-model').value
381+
const model_name = finetune_configs_and_runs.completion_model;
365382
let send_this = {
366383
"model": model_name,
367384
"lora_mode": lora_mode,
@@ -876,6 +893,8 @@ function start_log_stream(run_id) {
876893
};
877894
fetchData();
878895
}
896+
897+
879898
export async function init() {
880899
let req = await fetch('/tab-finetune.html');
881900
document.querySelector('#finetune').innerHTML = await req.text();

self_hosting_machinery/webgui/static/tab-model-hosting.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ <h3>Hosted Models</h3>
99
<tr>
1010
<th>Model</th>
1111
<th>Completion</th>
12+
<th>Finetune</th>
1213
<th>Sharding</th>
1314
<th>Share GPU</th>
1415
<th></th>

self_hosting_machinery/webgui/static/tab-model-hosting.js

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ function render_models_assigned(models) {
146146
row.setAttribute('data-model',index);
147147
const model_name = document.createElement("td");
148148
const completion = document.createElement("td");
149+
const finetune_info = document.createElement("td");
149150
const select_gpus = document.createElement("td");
150151
const gpus_share = document.createElement("td");
151152
const del = document.createElement("td");
@@ -168,6 +169,21 @@ function render_models_assigned(models) {
168169
completion.appendChild(completion_input);
169170
}
170171

172+
if (models_info[index].hasOwnProperty('finetune_info') && models_info[index].finetune_info) {
173+
finetune_info.innerHTML = `
174+
<table cellpadding="5">
175+
<tr>
176+
<td>Run: </td>
177+
<td>${models_info[index].finetune_info.run}</td>
178+
</tr>
179+
<tr>
180+
<td>Checkpoint: </td>
181+
<td>${models_info[index].finetune_info.checkpoint}</td>
182+
</tr>
183+
</table>
184+
`;
185+
}
186+
171187
if (models_info[index].hasOwnProperty('has_sharding') && models_info[index].has_sharding) {
172188
const select_gpus_div = document.createElement("div");
173189
select_gpus_div.setAttribute("class", "btn-group btn-group-sm");
@@ -233,6 +249,7 @@ function render_models_assigned(models) {
233249

234250
row.appendChild(model_name);
235251
row.appendChild(completion);
252+
row.appendChild(finetune_info);
236253
row.appendChild(select_gpus);
237254
row.appendChild(gpus_share);
238255
row.appendChild(del);

self_hosting_machinery/webgui/tab_finetune.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi.responses import Response, StreamingResponse, JSONResponse
1111

1212
from self_hosting_machinery.scripts import best_lora
13+
from self_hosting_machinery.webgui.selfhost_model_assigner import ModelAssigner
1314
from refact_data_pipeline.finetune.finetune_utils import get_active_loras
1415
from refact_data_pipeline.finetune.finetune_utils import get_finetune_config
1516
from refact_data_pipeline.finetune.finetune_utils import get_finetune_filter_stat
@@ -93,7 +94,7 @@ class TabFinetuneTrainingSetup(BaseModel):
9394

9495
class TabFinetuneRouter(APIRouter):
9596

96-
def __init__(self, models_db: Dict[str, Any], *args, **kwargs):
97+
def __init__(self, model_assigner: ModelAssigner, *args, **kwargs):
9798
super().__init__(*args, **kwargs)
9899
self.add_api_route("/tab-finetune-get", self._tab_finetune_get, methods=["GET"])
99100
self.add_api_route("/tab-finetune-config-and-runs", self._tab_finetune_config_and_runs, methods=["GET"])
@@ -109,7 +110,7 @@ def __init__(self, models_db: Dict[str, Any], *args, **kwargs):
109110
self.add_api_route("/tab-finetune-smart-filter-get", self._tab_finetune_smart_filter_get, methods=["GET"])
110111
self.add_api_route("/tab-finetune-training-setup", self._tab_finetune_training_setup, methods=["POST"])
111112
self.add_api_route("/tab-finetune-training-get", self._tab_finetune_training_get, methods=["GET"])
112-
self._models_db = models_db
113+
self._model_assigner = model_assigner
113114

114115
async def _tab_finetune_get(self):
115116
prog, status = get_prog_and_status_for_ui()
@@ -140,9 +141,11 @@ async def _tab_finetune_get_sources_status(self):
140141
return f"Error: {str(e)}"
141142

142143
async def _tab_finetune_config_and_runs(self):
144+
completion_model = self._model_assigner.model_assignment.get("completion", "")
143145
runs = get_finetune_runs()
144-
config = get_finetune_config(self._models_db)
146+
config = get_finetune_config(self._model_assigner.models_db)
145147
result = {
148+
"completion_model": completion_model,
146149
"finetune_runs": runs,
147150
"config": {
148151
"limit_training_time_minutes": "60",
@@ -151,8 +154,8 @@ async def _tab_finetune_config_and_runs(self):
151154
"auto_delete_n_runs": "5",
152155
**config, # TODO: why we mix finetune config for training and schedule?
153156
},
154-
"active": get_active_loras(self._models_db),
155-
"finetune_latest_best": best_lora.find_best_lora(config["model_name"]),
157+
"active": get_active_loras(self._model_assigner.models_db),
158+
"finetune_latest_best": best_lora.find_best_lora(completion_model),
156159
}
157160
return Response(json.dumps(result, indent=4) + "\n")
158161

@@ -188,7 +191,7 @@ async def _tab_finetune_training_setup(self, post: TabFinetuneTrainingSetup):
188191
async def _tab_finetune_training_get(self):
189192
result = {
190193
"defaults": finetune_train_defaults,
191-
"user_config": get_finetune_config(self._models_db),
194+
"user_config": get_finetune_config(self._model_assigner.models_db),
192195
}
193196
return Response(json.dumps(result, indent=4) + "\n")
194197

@@ -255,7 +258,7 @@ async def _tab_finetune_remove(self, run_id: str):
255258
return JSONResponse("OK")
256259

257260
async def _tab_finetune_activate(self, activate: TabFinetuneActivate):
258-
active_loras = get_active_loras(self._models_db)
261+
active_loras = get_active_loras(self._model_assigner.models_db)
259262
active_loras[activate.model] = activate.dict()
260263
with open(env.CONFIG_ACTIVE_LORA, "w") as f:
261264
json.dump(active_loras, f, indent=4)

self_hosting_machinery/webgui/webgui.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _routers_list(
7575
TabServerLogRouter(),
7676
TabUploadRouter(),
7777
TabFinetuneRouter(
78-
models_db=model_assigner.models_db),
78+
model_assigner=model_assigner),
7979
TabHostRouter(model_assigner),
8080
TabSettingsRouter(model_assigner),
8181
StaticRouter(),

0 commit comments

Comments
 (0)