Skip to content

Commit dd9dc12

Browse files
author
Raphael Gibson
committed
feat: add unique constraint param to Field function
1 parent 02da85c commit dd9dc12

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

sqlmodel/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
7070
primary_key = kwargs.pop("primary_key", False)
7171
nullable = kwargs.pop("nullable", Undefined)
7272
foreign_key = kwargs.pop("foreign_key", Undefined)
73+
unique = kwargs.pop("unique", False)
7374
index = kwargs.pop("index", Undefined)
7475
sa_column = kwargs.pop("sa_column", Undefined)
7576
sa_column_args = kwargs.pop("sa_column_args", Undefined)
@@ -89,6 +90,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
8990
self.primary_key = primary_key
9091
self.nullable = nullable
9192
self.foreign_key = foreign_key
93+
self.unique = unique
9294
self.index = index
9395
self.sa_column = sa_column
9496
self.sa_column_args = sa_column_args
@@ -150,6 +152,7 @@ def Field(
150152
regex: str = None,
151153
primary_key: bool = False,
152154
foreign_key: Optional[Any] = None,
155+
unique: bool = False,
153156
nullable: Union[bool, UndefinedType] = Undefined,
154157
index: Union[bool, UndefinedType] = Undefined,
155158
sa_column: Union[Column, UndefinedType] = Undefined,
@@ -180,6 +183,7 @@ def Field(
180183
regex=regex,
181184
primary_key=primary_key,
182185
foreign_key=foreign_key,
186+
unique=unique,
183187
nullable=nullable,
184188
index=index,
185189
sa_column=sa_column,
@@ -424,12 +428,14 @@ def get_column_from_field(field: ModelField) -> Column:
424428
nullable = field_nullable
425429
args = []
426430
foreign_key = getattr(field.field_info, "foreign_key", None)
431+
unique = getattr(field.field_info, "unique", False)
427432
if foreign_key:
428433
args.append(ForeignKey(foreign_key))
429434
kwargs = {
430435
"primary_key": primary_key,
431436
"nullable": nullable,
432437
"index": index,
438+
"unique": unique
433439
}
434440
sa_default = Undefined
435441
if field.field_info.default_factory:

tests/test_main.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import pytest
2+
from typing import Optional
3+
4+
from sqlmodel import Field, Session, SQLModel, create_engine
5+
from sqlalchemy.exc import IntegrityError
6+
7+
8+
def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):
9+
class Hero(SQLModel, table=True):
10+
id: Optional[int] = Field(default=None, primary_key=True)
11+
name: str
12+
secret_name: str
13+
age: Optional[int] = None
14+
15+
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
16+
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")
17+
18+
engine = create_engine("sqlite://")
19+
20+
SQLModel.metadata.create_all(engine)
21+
22+
with Session(engine) as session:
23+
session.add(hero_1)
24+
session.commit()
25+
session.refresh(hero_1)
26+
27+
with Session(engine) as session:
28+
session.add(hero_2)
29+
session.commit()
30+
session.refresh(hero_2)
31+
32+
with Session(engine) as session:
33+
heroes = session.query(Hero).all()
34+
assert len(heroes) == 2
35+
assert heroes[0].name == heroes[1].name
36+
37+
38+
def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel):
39+
class Hero(SQLModel, table=True):
40+
id: Optional[int] = Field(default=None, primary_key=True)
41+
name: str
42+
secret_name: str = Field(unique=False)
43+
age: Optional[int] = None
44+
45+
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
46+
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")
47+
48+
engine = create_engine("sqlite://")
49+
50+
SQLModel.metadata.create_all(engine)
51+
52+
with Session(engine) as session:
53+
session.add(hero_1)
54+
session.commit()
55+
session.refresh(hero_1)
56+
57+
with Session(engine) as session:
58+
session.add(hero_2)
59+
session.commit()
60+
session.refresh(hero_2)
61+
62+
with Session(engine) as session:
63+
heroes = session.query(Hero).all()
64+
assert len(heroes) == 2
65+
assert heroes[0].name == heroes[1].name
66+
67+
68+
def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(clear_sqlmodel):
69+
class Hero(SQLModel, table=True):
70+
id: Optional[int] = Field(default=None, primary_key=True)
71+
name: str
72+
secret_name: str = Field(unique=True)
73+
age: Optional[int] = None
74+
75+
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
76+
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")
77+
78+
engine = create_engine("sqlite://")
79+
80+
SQLModel.metadata.create_all(engine)
81+
82+
with Session(engine) as session:
83+
session.add(hero_1)
84+
session.commit()
85+
session.refresh(hero_1)
86+
87+
with pytest.raises(IntegrityError):
88+
with Session(engine) as session:
89+
session.add(hero_2)
90+
session.commit()
91+
session.refresh(hero_2)

0 commit comments

Comments
 (0)