Skip to content

Commit cdf72db

Browse files
committed
Inject mock API token into tests
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 1fef018 commit cdf72db

File tree

5 files changed

+20
-10
lines changed

5 files changed

+20
-10
lines changed

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
import os
2+
from unittest import mock
23

34
import pytest
45

56

7+
@pytest.fixture(scope="session")
8+
def mock_replicate_api_token(scope="class"):
9+
if os.environ.get("REPLICATE_API_TOKEN", "") != "":
10+
yield
11+
else:
12+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "test-token"}):
13+
yield
14+
15+
616
@pytest.fixture(scope="module")
717
def vcr_config():
818
return {"allowed_hosts": ["api.replicate.com"], "filter_headers": ["authorization"]}

tests/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
@pytest.mark.vcr("models-get.yaml")
77
@pytest.mark.asyncio
8-
async def test_models_get():
8+
async def test_models_get(mock_replicate_api_token):
99
model = replicate.models.get("stability-ai/sdxl")
1010

1111
assert model is not None

tests/test_prediction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
@pytest.mark.vcr("predictions-create.yaml")
77
@pytest.mark.asyncio
8-
async def test_predictions_create():
8+
async def test_predictions_create(mock_replicate_api_token):
99
input = {
1010
"prompt": "a studio photo of a rainbow colored corgi",
1111
"width": 512,
@@ -29,7 +29,7 @@ async def test_predictions_create():
2929

3030
@pytest.mark.vcr("predictions-get.yaml")
3131
@pytest.mark.asyncio
32-
async def test_predictions_get():
32+
async def test_predictions_get(mock_replicate_api_token):
3333
id = "vgcm4plb7tgzlyznry5d5jkgvu"
3434

3535
prediction = replicate.predictions.get(id)
@@ -39,7 +39,7 @@ async def test_predictions_get():
3939

4040
@pytest.mark.vcr("predictions-cancel.yaml")
4141
@pytest.mark.asyncio
42-
async def test_predictions_cancel():
42+
async def test_predictions_cancel(mock_replicate_api_token):
4343
input = {
4444
"prompt": "a studio photo of a rainbow colored corgi",
4545
"width": 512,

tests/test_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
@pytest.mark.vcr("run.yaml")
88
@pytest.mark.asyncio
9-
async def test_run():
9+
async def test_run(mock_replicate_api_token):
1010
version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5"
1111

1212
input = {
@@ -28,6 +28,6 @@ async def test_run():
2828

2929

3030
@pytest.mark.vcr
31-
def test_run_with_invalid_identifier():
31+
def test_run_with_invalid_identifier(mock_replicate_api_token):
3232
with pytest.raises(ReplicateError):
3333
replicate.run("invalid")

tests/test_training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
@pytest.mark.vcr("trainings-create.yaml")
1010
@pytest.mark.asyncio
11-
async def test_trainings_create():
11+
async def test_trainings_create(mock_replicate_api_token):
1212
training = replicate.trainings.create(
1313
"stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
1414
input={
@@ -24,7 +24,7 @@ async def test_trainings_create():
2424

2525
@pytest.mark.vcr("trainings-create__invalid-destination.yaml")
2626
@pytest.mark.asyncio
27-
async def test_trainings_create_with_invalid_destination():
27+
async def test_trainings_create_with_invalid_destination(mock_replicate_api_token):
2828
with pytest.raises(ReplicateException):
2929
replicate.trainings.create(
3030
"stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
@@ -37,7 +37,7 @@ async def test_trainings_create_with_invalid_destination():
3737

3838
@pytest.mark.vcr("trainings-get.yaml")
3939
@pytest.mark.asyncio
40-
async def test_trainings_get():
40+
async def test_trainings_get(mock_replicate_api_token):
4141
id = "ckcbvmtbvg6di3b3uhvccytnfm"
4242

4343
training = replicate.trainings.get(id)
@@ -48,7 +48,7 @@ async def test_trainings_get():
4848

4949
@pytest.mark.vcr("trainings-cancel.yaml")
5050
@pytest.mark.asyncio
51-
async def test_trainings_cancel():
51+
async def test_trainings_cancel(mock_replicate_api_token):
5252
input = {
5353
"input_images": input_images_url,
5454
"use_face_detection_instead": True,

0 commit comments

Comments
 (0)