@@ -1370,14 +1370,24 @@ class StableDiffusionGGML {
1370
1370
ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1371
1371
int64_t t0 = ggml_time_ms ();
1372
1372
ggml_tensor* result = NULL ;
1373
+ int W = x->ne [0 ] / 8 ;
1374
+ int H = x->ne [1 ] / 8 ;
1375
+ if (vae_tiling && !decode_video) {
1376
+ // TODO wan2.2 vae support?
1377
+ int C = sd_version_is_dit (version) ? 16 : 4 ;
1378
+ if (!use_tiny_autoencoder) {
1379
+ C *= 2 ;
1380
+ }
1381
+ result = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, W, H, C, x->ne [3 ]);
1382
+ }
1373
1383
// TODO: args instead of env for tile size / overlap?
1374
1384
if (!use_tiny_autoencoder) {
1375
1385
float tile_overlap = 0 .5f ;
1376
1386
int tile_size_x = 32 ;
1377
1387
int tile_size_y = 32 ;
1378
1388
1379
1389
get_vae_tile_overlap (tile_overlap);
1380
- get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x-> ne [ 0 ] / 8 , x-> ne [ 1 ] / 8 );
1390
+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, W, H );
1381
1391
1382
1392
// TODO: also use an arg for this one?
1383
1393
// multiply tile size for encode to keep the compute buffer size consistent
@@ -1387,7 +1397,7 @@ class StableDiffusionGGML {
1387
1397
process_vae_input_tensor (x);
1388
1398
if (vae_tiling && !decode_video) {
1389
1399
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1390
- first_stage_model->compute (n_threads, in, true , &out, NULL );
1400
+ first_stage_model->compute (n_threads, in, false , &out, work_ctx );
1391
1401
};
1392
1402
sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap, on_tiling);
1393
1403
} else {
@@ -1398,7 +1408,7 @@ class StableDiffusionGGML {
1398
1408
if (vae_tiling && !decode_video) {
1399
1409
// split latent in 32x32 tiles and compute in several steps
1400
1410
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1401
- tae_first_stage->compute (n_threads, in, true , &out, NULL );
1411
+ tae_first_stage->compute (n_threads, in, false , &out, NULL );
1402
1412
};
1403
1413
sd_tiling (x, result, 8 , 64 , 0 .5f , on_tiling);
1404
1414
} else {
0 commit comments