Skip to content

Commit 2b37d44

Browse files
committed
Use torch.accelerator in DCGAN example
1 parent 7ebe86d commit 2b37d44

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

dcgan/main.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,8 @@ def __init__(self, ngpu):
152152

153153
def forward(self, input):
154154

155-
if input.is_cuda and self.ngpu > 1:
155+
if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
156156
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
157-
if input.is_xpu and self.ngpu > 1:
158-
output = nn.DataParallel(self.main, input, range(self.ngpu))
159157
else:
160158
output = self.main(input)
161159
return output
@@ -194,10 +192,8 @@ def __init__(self, ngpu):
194192
)
195193

196194
def forward(self, input):
197-
if input.is_cuda and self.ngpu > 1:
195+
if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
198196
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
199-
if input.is_xpu and self.ngpu > 1:
200-
output = nn.DataParallel(self.main, input, range(self.ngpu))
201197
else:
202198
output = self.main(input)
203199

run_python_examples.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ case $USE_CUDA in
5858
esac
5959

6060
function dcgan() {
61-
uv run main.py --dataset fake --accel --dry-run || error "dcgan failed"
61+
uv run main.py --dataset fake $ACCEL_FLAG --dry-run || error "dcgan failed"
6262
}
6363

6464
function fast_neural_style() {

0 commit comments

Comments
 (0)