diff --git a/src/huggingface_hub/cli/repo.py b/src/huggingface_hub/cli/repo.py index 78ae694bba..793c1247c8 100644 --- a/src/huggingface_hub/cli/repo.py +++ b/src/huggingface_hub/cli/repo.py @@ -21,6 +21,7 @@ hf repo create my-cool-model --private """ +import enum from typing import Annotated, Optional import typer @@ -44,8 +45,16 @@ logger = logging.get_logger(__name__) repo_cli = typer_factory(help="Manage repos on the Hub.") -tag_app = typer_factory(help="Manage tags for a repo on the Hub.") -repo_cli.add_typer(tag_app, name="tag") +tag_cli = typer_factory(help="Manage tags for a repo on the Hub.") +branch_cli = typer_factory(help="Manage branches for a repo on the Hub.") +repo_cli.add_typer(tag_cli, name="tag") +repo_cli.add_typer(branch_cli, name="branch") + + +class GatedChoices(str, enum.Enum): + auto = "auto" + manual = "manual" + false = "false" @repo_cli.command("create", help="Create a new repo on the Hub.") @@ -87,7 +96,130 @@ def repo_create( print(f"Your repo is now available at {ANSI.bold(repo_url)}") -@tag_app.command("create", help="Create a tag for a repo.") +@repo_cli.command("delete", help="Delete a repo from the Hub. this is an irreversible operation.") +def repo_delete( + repo_id: RepoIdArg, + repo_type: RepoTypeOpt = RepoType.model, + token: TokenOpt = None, + missing_ok: Annotated[ + bool, + typer.Option( + help="If set to True, do not raise an error if repo does not exist.", + ), + ] = False, +) -> None: + api = get_hf_api(token=token) + api.delete_repo( + repo_id=repo_id, + repo_type=repo_type.value, + missing_ok=missing_ok, + ) + print(f"Successfully deleted {ANSI.bold(repo_id)} on the Hub.") + + +@repo_cli.command("move", help="Move a repository from a namespace to another namespace.") +def repo_move( + from_id: RepoIdArg, + to_id: RepoIdArg, + token: TokenOpt = None, + repo_type: RepoTypeOpt = RepoType.model, +) -> None: + api = get_hf_api(token=token) + api.move_repo( + from_id=from_id, + to_id=to_id, + repo_type=repo_type.value, + ) + print(f"Successfully moved {ANSI.bold(from_id)} to {ANSI.bold(to_id)} on the Hub.") + + +@repo_cli.command("settings", help="Update the settings of a repository.") +def repo_settings( + repo_id: RepoIdArg, + gated: Annotated[ + Optional[GatedChoices], + typer.Option( + help="The gated status for the repository.", + ), + ] = None, + private: Annotated[ + Optional[bool], + typer.Option( + help="Whether the repository should be private.", + ), + ] = None, + xet_enabled: Annotated[ + Optional[bool], + typer.Option( + help=" Whether the repository should be enabled for Xet Storage.", + ), + ] = None, + token: TokenOpt = None, + repo_type: RepoTypeOpt = RepoType.model, +) -> None: + api = get_hf_api(token=token) + api.update_repo_settings( + repo_id=repo_id, + gated=(gated.value if gated else None), # type: ignore [arg-type] + private=private, + xet_enabled=xet_enabled, + repo_type=repo_type.value, + ) + print(f"Successfully updated the settings of {ANSI.bold(repo_id)} on the Hub.") + + +@branch_cli.command("create", help="Create a new branch for a repo on the Hub.") +def branch_create( + repo_id: RepoIdArg, + branch: Annotated[ + str, + typer.Argument( + help="The name of the branch to create.", + ), + ], + revision: RevisionOpt = None, + token: TokenOpt = None, + repo_type: RepoTypeOpt = RepoType.model, + exist_ok: Annotated[ + bool, + typer.Option( + help="If set to True, do not raise an error if branch already exists.", + ), + ] = False, +) -> None: + api = get_hf_api(token=token) + api.create_branch( + repo_id=repo_id, + branch=branch, + revision=revision, + repo_type=repo_type.value, + exist_ok=exist_ok, + ) + print(f"Successfully created {ANSI.bold(branch)} branch on {repo_type.value} {ANSI.bold(repo_id)}") + + +@branch_cli.command("delete", help="Delete a branch from a repo on the Hub.") +def branch_delete( + repo_id: RepoIdArg, + branch: Annotated[ + str, + typer.Argument( + help="The name of the branch to delete.", + ), + ], + token: TokenOpt = None, + repo_type: RepoTypeOpt = RepoType.model, +) -> None: + api = get_hf_api(token=token) + api.delete_branch( + repo_id=repo_id, + branch=branch, + repo_type=repo_type.value, + ) + print(f"Successfully deleted {ANSI.bold(branch)} branch on {repo_type.value} {ANSI.bold(repo_id)}") + + +@tag_cli.command("create", help="Create a tag for a repo.") def tag_create( repo_id: RepoIdArg, tag: Annotated[ @@ -127,7 +259,7 @@ def tag_create( print(f"Tag {ANSI.bold(tag)} created on {ANSI.bold(repo_id)}") -@tag_app.command("list", help="List tags for a repo.") +@tag_cli.command("list", help="List tags for a repo.") def tag_list( repo_id: RepoIdArg, token: TokenOpt = None, @@ -152,7 +284,7 @@ def tag_list( print(t.name) -@tag_app.command("delete", help="Delete a tag for a repo.") +@tag_cli.command("delete", help="Delete a tag for a repo.") def tag_delete( repo_id: RepoIdArg, tag: Annotated[ diff --git a/tests/test_cli.py b/tests/test_cli.py index 0498303224..391502fc9c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -778,6 +778,208 @@ def test_tag_delete_basic(self, runner: CliRunner) -> None: api.delete_tag.assert_called_once_with(repo_id=DUMMY_MODEL_ID, tag="1.0", repo_type="model") +class TestBranchCommands: + def test_branch_create_basic(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke(app, ["repo", "branch", "create", DUMMY_MODEL_ID, "dev"]) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token=None) + api.create_branch.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + branch="dev", + revision=None, + repo_type="model", + exist_ok=False, + ) + + def test_branch_create_with_all_options(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke( + app, + [ + "repo", + "branch", + "create", + DUMMY_MODEL_ID, + "dev", + "--repo-type", + "dataset", + "--revision", + "v1.0.0", + "--token", + "my-token", + "--exist-ok", + ], + ) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token="my-token") + api.create_branch.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + branch="dev", + revision="v1.0.0", + repo_type="dataset", + exist_ok=True, + ) + + def test_branch_delete_basic(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke(app, ["repo", "branch", "delete", DUMMY_MODEL_ID, "dev"]) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token=None) + api.delete_branch.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + branch="dev", + repo_type="model", + ) + + def test_branch_delete_with_all_options(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke( + app, + [ + "repo", + "branch", + "delete", + DUMMY_MODEL_ID, + "dev", + "--repo-type", + "dataset", + "--token", + "my-token", + ], + ) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token="my-token") + api.delete_branch.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + branch="dev", + repo_type="dataset", + ) + + +class TestRepoMoveCommand: + def test_repo_move_basic(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke(app, ["repo", "move", DUMMY_MODEL_ID, "new-id"]) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token=None) + api.move_repo.assert_called_once_with( + from_id=DUMMY_MODEL_ID, + to_id="new-id", + repo_type="model", + ) + + def test_repo_move_with_all_options(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke( + app, + [ + "repo", + "move", + DUMMY_MODEL_ID, + "new-id", + "--repo-type", + "dataset", + "--token", + "my-token", + ], + ) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token="my-token") + api.move_repo.assert_called_once_with( + from_id=DUMMY_MODEL_ID, + to_id="new-id", + repo_type="dataset", + ) + + +class TestRepoSettingsCommand: + def test_repo_settings_basic(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke(app, ["repo", "settings", DUMMY_MODEL_ID]) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token=None) + api.update_repo_settings.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + gated=None, + private=None, + xet_enabled=None, + repo_type="model", + ) + + def test_repo_settings_with_all_options(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke( + app, + [ + "repo", + "settings", + DUMMY_MODEL_ID, + "--gated", + "manual", + "--private", + "--repo-type", + "dataset", + "--token", + "my-token", + ], + ) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token="my-token") + kwargs = api.update_repo_settings.call_args.kwargs + assert kwargs["repo_id"] == DUMMY_MODEL_ID + assert kwargs["repo_type"] == "dataset" + assert kwargs["private"] is True + assert kwargs["xet_enabled"] is None + assert kwargs["gated"] == "manual" + + +class TestRepoDeleteCommand: + def test_repo_delete_basic(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke(app, ["repo", "delete", DUMMY_MODEL_ID]) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token=None) + api.delete_repo.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + repo_type="model", + missing_ok=False, + ) + + def test_repo_delete_with_all_options(self, runner: CliRunner) -> None: + with patch("huggingface_hub.cli.repo.get_hf_api") as api_cls: + api = api_cls.return_value + result = runner.invoke( + app, + [ + "repo", + "delete", + DUMMY_MODEL_ID, + "--repo-type", + "dataset", + "--token", + "my-token", + "--missing-ok", + ], + ) + assert result.exit_code == 0 + api_cls.assert_called_once_with(token="my-token") + api.delete_repo.assert_called_once_with( + repo_id=DUMMY_MODEL_ID, + repo_type="dataset", + missing_ok=True, + ) + + @contextmanager def tmp_current_directory() -> Generator[str, None, None]: with SoftTemporaryDirectory() as tmp_dir: