-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvector_quantize.py
2240 lines (1813 loc) · 86 KB
/
vector_quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
from functools import partial, cache
from collections import namedtuple
import torch
from torch.nn import Module
from torch import nn, einsum, Tensor
import torch.nn.functional as F
import torch.distributed as distributed
from torch.optim import Optimizer
from torch.amp import autocast
import einx
from einops import rearrange, repeat, reduce, pack, unpack
from typing import Callable
def exists(val):
"""
检查一个值是否存在(不为 None)。
参数:
val: 需要检查的值。
返回:
bool: 如果 val 不为 None,则返回 True;否则返回 False。
"""
return val is not None
def default(val, d):
"""
返回可选值或默认值。
参数:
val: 需要检查的可选值。
d: 默认值。
返回:
Any: 如果 val 存在,则返回 val;否则返回 d。
"""
return val if exists(val) else d
def noop(*args, **kwargs):
"""
空操作函数,不执行任何操作。
参数:
*args: 任意位置参数。
**kwargs: 任意关键字参数。
"""
pass
def identity(t):
"""
恒等函数,返回输入值不变。
参数:
t: 输入值。
返回:
Any: 输入值 t。
"""
return t
def l2norm(t, dim = -1, eps = 1e-6):
"""
对张量 t 进行 L2 归一化。
参数:
t (Tensor): 输入张量。
dim (int, 可选): 需要归一化的维度。默认值为 -1。
eps (float, 可选): 防止除以零的极小值。默认值为 1e-6。
返回:
Tensor: 归一化后的张量。
"""
return F.normalize(t, p = 2, dim = dim, eps = eps)
def safe_div(num, den, eps = 1e-6):
"""
安全除法函数,防止除以零。
参数:
num (Tensor): 分子张量。
den (Tensor): 分母张量。
eps (float, 可选): 防止除以零的极小值。默认值为 1e-6。
返回:
Tensor: 除法结果。
"""
return num / den.clamp(min = eps)
def Sequential(*modules):
"""
创建一个顺序模型,仅包含存在的模块。
参数:
*modules: 任意数量的 nn.Module 实例。
返回:
nn.Sequential: 包含所有存在模块的顺序模型。如果没有模块,则返回 None;如果只有一个模块,则返回该模块。
"""
modules = [*filter(exists, modules)]
if len(modules) == 0:
return None
elif len(modules) == 1:
return modules[0]
return nn.Sequential(*modules)
def cdist(x, y):
"""
计算两个张量 x 和 y 之间的成对欧几里得距离。
参数:
x (Tensor): 第一个输入张量,形状为 (batch_size, n, d)。
y (Tensor): 第二个输入张量,形状为 (batch_size, m, d)。
返回:
Tensor: 欧几里得距离矩阵,形状为 (batch_size, n, m)。
"""
x2 = reduce(x ** 2, 'b n d -> b n', 'sum') # 计算 x 的平方和,形状为 (batch_size, n)
y2 = reduce(y ** 2, 'b n d -> b n', 'sum') # 计算 y 的平方和,形状为 (batch_size, m)
xy = einsum('b i d, b j d -> b i j', x, y) * -2 # 计算 x 和 y 的点积,形状为 (batch_size, n, m)
# 计算欧几里得距离,并确保结果为非负数
return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = 0).sqrt()
def log(t, eps = 1e-20):
"""
对张量 t 进行对数运算,并防止数值下溢。
参数:
t (Tensor): 输入张量。
eps (float, 可选): 防止数值下溢的极小值。默认值为 1e-20。
返回:
Tensor: 对数运算后的张量。
"""
return torch.log(t.clamp(min = eps))
def entropy(prob, eps = 1e-5):
"""
计算概率分布的熵。
参数:
prob (Tensor): 概率分布张量,形状为 (batch_size, ..., n)。
eps (float, 可选): 防止数值下溢的极小值。默认值为 1e-5。
返回:
Tensor: 熵,形状为 (batch_size, ...)。
"""
return (-prob * log(prob, eps = eps)).sum(dim = -1)
def ema_inplace(old, new, decay):
"""
对旧张量 old 进行指数移动平均(EMA)更新。
参数:
old (Tensor): 旧张量。
new (Tensor): 新张量。
decay (float): 衰减因子。
"""
is_mps = str(old.device).startswith('mps:')
if not is_mps:
old.lerp_(new, 1 - decay)
else:
old.mul_(decay).add_(new * (1 - decay))
def pack_one(t, pattern):
"""
将单个张量 t 按照指定的 pattern 打包,并返回一个解包函数。
参数:
t (Tensor): 需要打包的张量。
pattern (str): 打包的模式。
返回:
Tuple[Tensor, callable]: 返回打包后的张量和解包函数。
"""
packed, ps = pack([t], pattern)
def unpack_one(to_unpack, unpack_pattern = None):
"""
解包函数,将打包后的张量 to_unpack 解包回原始形状。
参数:
to_unpack (Tensor): 需要解包的张量。
unpack_pattern (Optional[str], 可选): 解包的 pattern。如果未提供,则使用原始的 pattern。
返回:
Tensor: 解包后的张量。
"""
unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
return unpacked
return packed, unpack_one
def lens_to_mask(lens, max_length):
"""
根据序列长度 lens 生成掩码张量。
参数:
lens (Tensor): 序列长度张量,形状为 (batch_size,)。
max_length (int): 最大序列长度。
返回:
Tensor: 掩码张量,形状为 (batch_size, max_length)。
"""
# 生成序列索引张量
seq = torch.arange(max_length, device = lens.device)
# 生成掩码,标记有效位置
return seq < lens[:, None]
def uniform_init(*shape):
"""
使用 Kaiming 均匀初始化方法初始化一个张量。
参数:
*shape: 张量的形状参数。
返回:
Tensor: 初始化后的张量。
"""
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
def gumbel_noise(t):
"""
生成 Gumbel 噪声。
参数:
t (Tensor): 输入张量,用于确定噪声的形状。
返回:
Tensor: 与输入张量形状相同的 Gumbel 噪声。
"""
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(
logits,
temperature = 1.,
stochastic = False,
straight_through = False,
dim = -1,
training = True
):
"""
使用 Gumbel-Softmax 对输入的对数概率进行采样。
参数:
logits (Tensor): 输入的对数概率张量。
temperature (float, 可选): 温度参数。默认值为 1.0。
stochastic (bool, 可选): 是否进行随机采样。默认值为 False。
straight_through (bool, 可选): 是否使用直通梯度估计。默认值为 False。
dim (int, 可选): 沿着哪个维度进行采样。默认值为 -1。
training (bool, 可选): 是否在训练模式下进行采样。默认值为 True。
返回:
Tuple[Tensor, Tensor]: 返回采样索引和对应的 one-hot 编码。
"""
# 获取数据类型和采样维度的大小
dtype, size = logits.dtype, logits.shape[dim]
if training and stochastic and temperature > 0:
# 如果在训练模式下进行随机采样,并且温度大于 0,则应用 Gumbel-Softmax
# 计算采样对数概率
sampling_logits = (logits / temperature) + gumbel_noise(logits)
else:
# 否则,直接使用输入的对数概率
sampling_logits = logits
# 沿着指定维度进行 argmax 操作,得到采样索引
ind = sampling_logits.argmax(dim = dim)
# 将采样索引转换为 one-hot 编码
one_hot = F.one_hot(ind, size).type(dtype)
if not straight_through or temperature <= 0. or not training:
# 如果不使用直通梯度估计,或者温度小于等于 0,或者不在训练模式下,则返回采样索引和 one-hot 编码
return ind, one_hot
# 计算 softmax 概率
π1 = (logits / temperature).softmax(dim = dim)
# 应用直通梯度估计,保留梯度信息
one_hot = one_hot + π1 - π1.detach()
# 返回采样索引和修正后的 one-hot 编码
return ind, one_hot
def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
"""
应用 Laplace 平滑到类别分布上。
参数:
x (Tensor): 输入的类别分布张量。
n_categories (int): 类别数量。
eps (float, 可选): 平滑因子。默认值为 1e-5。
dim (int, 可选): 沿着哪个维度进行平滑。默认值为 -1。
返回:
Tensor: 平滑后的类别分布。
"""
denom = x.sum(dim = dim, keepdim = True)
return (x + eps) / (denom + n_categories * eps)
def sample_vectors(samples, num):
"""
从样本中随机采样指定数量的向量。
参数:
samples (Tensor): 输入的样本张量。
num (int): 需要采样的数量。
返回:
Tensor: 采样后的样本。
"""
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device = device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device = device)
return samples[indices]
def batched_sample_vectors(samples, num):
"""
对批量样本进行批量采样。
参数:
samples (Tensor): 批量样本张量,形状为 (batch_size, ...)。
num (int): 每个样本需要采样的数量。
返回:
Tensor: 批量采样后的样本,形状为 (batch_size, num, ...)。
"""
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)
def pad_shape(shape, size, dim = 0):
"""
对形状列表进行填充。
参数:
shape (List[int]): 原始形状列表。
size (int): 填充的大小。
dim (int, 可选): 需要填充的维度。默认值为 0。
返回:
List[int]: 填充后的形状列表。
"""
return [size if i == dim else s for i, s in enumerate(shape)]
def sample_multinomial(total_count, probs):
"""
从多项式分布中采样。
参数:
total_count (int): 总计数。
probs (Tensor): 多项式分布的概率张量。
返回:
Tensor: 采样结果,形状与 probs 相同。
"""
device = probs.device
probs = probs.cpu()
# 创建一个新的张量,形状与 probs 相同,值为 total_count
total_count = probs.new_full((), total_count)
# 创建一个新的张量,形状与 probs 相同,值为 1
remainder = probs.new_ones(())
# 创建一个与 probs 形状相同的空张量,用于存储采样结果
sample = torch.empty_like(probs, dtype = torch.long)
for i, p in enumerate(probs):
s = torch.binomial(total_count, p / remainder) # 从二项分布中采样
sample[i] = s # 存储采样结果
total_count -= s # 更新总计数
remainder -= p # 更新剩余概率
# 确保总计数为 0
assert total_count == 0, f'invalid total count {total_count}'
# 返回采样结果,并移动回原始设备
return sample.to(device)
def all_gather_sizes(x, dim):
"""
收集所有进程在指定维度上的尺寸。
参数:
x (Tensor): 输入张量。
dim (int): 需要收集尺寸的维度。
返回:
Tensor: 所有进程在该维度上的尺寸组成的张量。
"""
# 获取指定维度的尺寸,并转换为长整型张量
size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device)
# 创建一个列表,包含与 world_size 相同数量的空张量
all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
# 在所有进程之间收集尺寸信息
distributed.all_gather(all_sizes, size)
# 将所有尺寸堆叠成一个张量并返回
return torch.stack(all_sizes)
def all_gather_variably_sized(x, sizes, dim = 0):
"""
收集不同尺寸的张量。
参数:
x (Tensor): 输入张量。
sizes (List[int]): 每个进程在指定维度上的尺寸列表。
dim (int, 可选): 需要收集的维度。默认值为 0。
返回:
List[Tensor]: 收集到的所有张量列表。
"""
# 获取当前进程的 rank
rank = distributed.get_rank()
# 初始化收集到的张量列表
all_x = []
for i, size in enumerate(sizes):
# 如果当前进程的 rank 与索引 i 相同,则使用输入张量 x
# 否则,创建一个与 x 形状相同但指定维度填充为 size 的空张量
t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
# 从源进程 i 广播张量 t
distributed.broadcast(t, src = i, async_op = True)
# 将广播后的张量添加到列表中
all_x.append(t)
# 等待所有进程完成广播操作
distributed.barrier()
return all_x
def sample_vectors_distributed(local_samples, num):
"""
在分布式环境中对本地样本进行采样,并收集所有样本。
参数:
local_samples (Tensor): 本地样本张量。
num (int): 需要采样的总数量。
返回:
Tensor: 采样后的样本张量。
"""
# 重塑本地样本张量的形状
local_samples = rearrange(local_samples, '1 ... -> ...')
# 获取当前进程的 rank
rank = distributed.get_rank()
# 收集所有进程在指定维度上的样本数量
all_num_samples = all_gather_sizes(local_samples, dim = 0)
if rank == 0:
# 如果当前进程是 rank 0,则根据所有进程的总样本数量进行多项式采样,确定每个进程需要采样的数量
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
else:
# 否则,创建一个与 all_num_samples 形状相同的空张量
samples_per_rank = torch.empty_like(all_num_samples)
# 从 rank 0 广播每个进程需要采样的数量
distributed.broadcast(samples_per_rank, src = 0)
# 将采样数量转换为列表
samples_per_rank = samples_per_rank.tolist()
# 对本地样本进行采样
local_samples = sample_vectors(local_samples, samples_per_rank[rank])
# 收集所有进程采样后的样本
all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0)
# 将所有样本连接成一个张量
out = torch.cat(all_samples, dim = 0)
# 重塑输出张量的形状并返回
return rearrange(out, '... -> 1 ...')
def batched_bincount(x, *, minlength):
"""
对批量数据进行单热编码(bincount 操作)。
参数:
x (Tensor): 输入张量,形状为 (batch_size, ...)。
minlength (int): 输出张量的最小长度。
返回:
Tensor: 单热编码后的张量,形状为 (batch_size, minlength)。
"""
batch, dtype, device = x.shape[0], x.dtype, x.device
target = torch.zeros(batch, minlength, dtype = dtype, device = device)
values = torch.ones_like(x)
target.scatter_add_(-1, x, values)
return target
def kmeans(
samples,
num_clusters,
num_iters = 10,
use_cosine_sim = False,
sample_fn = batched_sample_vectors,
all_reduce_fn = noop
):
"""
K-Means 聚类算法。
参数:
samples (Tensor): 输入样本张量。
num_clusters (int): 聚类数量。
num_iters (int, 可选): 迭代次数。默认值为 10。
use_cosine_sim (bool, 可选): 是否使用余弦相似度进行距离计算。默认值为 False。
sample_fn (callable, 可选): 采样函数。默认使用 batched_sample_vectors。
all_reduce_fn (callable, 可选): 全局归约函数。默认使用 noop。
返回:
Tuple[Tensor, Tensor]: 返回聚类中心张量和每个样本所属的聚类索引。
"""
# 获取码本数量、维度、数据类型和设备
num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device
# 从样本中采样初始聚类中心
means = sample_fn(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
# 如果使用余弦相似度,则计算样本与聚类中心的相似度
dists = samples @ rearrange(means, 'h n d -> h d n')
else:
# 否则,计算欧几里得距离
dists = -cdist(samples, means)
# 根据距离分配每个样本到最近的聚类中心
buckets = torch.argmax(dists, dim = -1)
# 计算每个聚类的样本数量
bins = batched_bincount(buckets, minlength = num_clusters)
# 进行全局归约(如果需要)
all_reduce_fn(bins)
# 标记没有样本的聚类
zero_mask = bins == 0
# 将没有样本的聚类的数量设为 1,避免除以零
bins_min_clamped = bins.masked_fill(zero_mask, 1)
# 初始化新的聚类中心张量
new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)
# 累加每个聚类中的样本
new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)
# 计算新的聚类中心
new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
# 进行全局归约(如果需要)
all_reduce_fn(new_means)
if use_cosine_sim:
# 如果使用余弦相似度,则对新的聚类中心进行 L2 归一化
new_means = l2norm(new_means)
# 更新聚类中心,如果某个聚类没有样本,则保留原来的聚类中心
means = torch.where(
rearrange(zero_mask, '... -> ... 1'),
means,
new_means
)
# 返回最终的聚类中心和每个样本所属的聚类索引
return means, bins
# rotation trick related
def efficient_rotation_trick_transform(u, q, e):
"""
4.2 in https://arxiv.org/abs/2410.06424
"""
"""
应用旋转技巧变换。
参数:
u (Tensor): 输入张量 u。
q (Tensor): 输入张量 q。
e (Tensor): 输入张量 e。
返回:
Tensor: 应用旋转技巧变换后的张量。
"""
# 重塑 e 的形状为 (batch_size, 1, dim)
e = rearrange(e, 'b d -> b 1 d')
# 计算 u + q 的 L2 归一化,并分离计算图
w = l2norm(u + q, dim = 1).detach()
return (
e -
2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) + # 计算旋转项
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach()) # 计算缩放项
)
def rotate_to(src, tgt):
# rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
"""
应用旋转技巧(Rotation Trick)以在 VQ 层中传递梯度。
参数:
src (Tensor): 源张量,形状为 (batch_size, ..., d)。
tgt (Tensor): 目标张量,形状为 (batch_size, ..., d)。
返回:
Tensor: 应用旋转技巧后的张量。
"""
# 对 src 和 tgt 进行打包,并保存解包函数
src, inverse = pack_one(src, '* d') # src 的形状为 (batch_size, ..., d)
tgt, _ = pack_one(tgt, '* d') # tgt 的形状为 (batch_size, ..., d)
# 计算 src 和 tgt 的 L2 范数
norm_src = src.norm(dim = -1, keepdim = True) # src 的 L2 范数,形状为 (batch_size, ..., 1)
norm_tgt = tgt.norm(dim = -1, keepdim = True) # tgt 的 L2 范数,形状为 (batch_size, ..., 1)
# 应用旋转技巧变换
rotated_tgt = efficient_rotation_trick_transform(
safe_div(src, norm_src), # src 的单位向量
safe_div(tgt, norm_tgt), # tgt 的单位向量
src # src 本身
).squeeze() # 旋转后的 tgt,形状为 (batch_size, ..., d)
# 调整旋转后的 tgt 的范数,使其与 src 的范数成比例
rotated = rotated_tgt * safe_div(norm_tgt, norm_src).detach() # 旋转后的张量
# 返回解包后的旋转后的张量
return inverse(rotated)
# distributed helpers
@cache # 缓存函数结果
def is_distributed():
"""
判断当前环境是否为分布式环境。
返回:
bool: 如果是分布式环境且 world_size 大于 1,则返回 True;否则返回 False。
"""
return distributed.is_initialized() and distributed.get_world_size() > 1
# regularization losses
def orthogonal_loss_fn(t):
# eq (2) from https://arxiv.org/abs/2112.00384
"""
正交损失函数,用于正则化。
参数:
t (Tensor): 输入张量,形状为 (batch_size, ..., d)。
返回:
Tensor: 正交损失值。
"""
# 获取张量的前两个维度的大小
h, n = t.shape[:2]
# 对输入张量进行 L2 归一化,形状为 (batch_size, ..., d)
normed_codes = l2norm(t)
# 计算归一化后的张量的余弦相似度矩阵,形状为 (batch_size, ..., d, d)
cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
# 计算正交损失, 返回损失值
return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)
# distance types
class EuclideanCodebook(Module):
def __init__(
self,
dim,
codebook_size,
num_codebooks = 1,
kmeans_init = False,
kmeans_iters = 10,
sync_kmeans = True,
decay = 0.8,
eps = 1e-5,
threshold_ema_dead_code = 2,
reset_cluster_size = None,
use_ddp = False,
learnable_codebook = False,
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
ema_update = True,
manual_ema_update = False,
affine_param = False,
sync_affine_param = False,
affine_param_batch_decay = 0.99,
affine_param_codebook_decay = 0.9
):
"""
欧几里得码本,用于量化输入向量。
参数:
dim (int): 特征维度。
codebook_size (int): 码本中向量的数量。
num_codebooks (int, 可选): 码本的数量。默认值为 1。
kmeans_init (bool, 可选): 是否使用 K-Means 初始化码本。默认值为 False。
kmeans_iters (int, 可选): K-Means 聚类的迭代次数。默认值为 10。
sync_kmeans (bool, 可选): 是否在分布式环境中同步 K-Means。默认值为 True。
decay (float, 可选): 指数移动平均(EMA)的衰减因子。默认值为 0.8。
eps (float, 可选): 防止除以零的极小值。默认值为 1e-5。
threshold_ema_dead_code (int, 可选): EMA 死码的阈值。默认值为 2。
reset_cluster_size (int, 可选): 重置聚类大小的阈值。默认值为 threshold_ema_dead_code。
use_ddp (bool, 可选): 是否使用分布式数据并行(Distributed Data Parallel)。默认值为 False。
learnable_codebook (bool, 可选): 是否使码本可学习。默认值为 False。
gumbel_sample (callable, 可选): Gumbel 采样函数。默认使用 gumbel_sample 函数。
sample_codebook_temp (float, 可选): 采样码本的温度参数。默认值为 1.0。
ema_update (bool, 可选): 是否使用 EMA 更新。默认值为 True。
manual_ema_update (bool, 可选): 是否手动更新 EMA。默认值为 False。
affine_param (bool, 可选): 是否使用仿射参数。默认值为 False。
sync_affine_param (bool, 可选): 是否在分布式环境中同步仿射参数。默认值为 False。
affine_param_batch_decay (float, 可选): 仿射参数的批次衰减因子。默认值为 0.99。
affine_param_codebook_decay (float, 可选): 仿射参数的码本衰减因子。默认值为 0.9。
"""
super().__init__()
# 设置输入转换函数为恒等函数
self.transform_input = identity
# 指数移动平均(EMA)的衰减因子
self.decay = decay
# 是否使用 EMA 更新
self.ema_update = ema_update
# 是否手动更新 EMA
self.manual_ema_update = manual_ema_update
# 初始化码本嵌入
# 选择初始化函数
init_fn = uniform_init if not kmeans_init else torch.zeros
# 初始化码本嵌入
embed = init_fn(num_codebooks, codebook_size, dim)
# 码本中向量的数量
self.codebook_size = codebook_size
# 码本的数量
self.num_codebooks = num_codebooks
# K-Means 聚类的迭代次数
self.kmeans_iters = kmeans_iters
# 防止除以零的极小值
self.eps = eps
# EMA 死码的阈值
self.threshold_ema_dead_code = threshold_ema_dead_code
# 重置聚类大小的阈值
self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
# 确保 gumbel_sample 是一个可调用的函数
assert callable(gumbel_sample)
# Gumbel 采样函数
self.gumbel_sample = gumbel_sample
# 采样码本的温度参数
self.sample_codebook_temp = sample_codebook_temp
# 检查是否在分布式环境中使用 K-Means 初始化多个码本
assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
# 选择采样函数
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
# 选择全局归约函数
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
# 注册缓冲区
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
# 聚类大小
self.register_buffer('cluster_size', torch.ones(num_codebooks, codebook_size))
# 码本嵌入的平均值
self.register_buffer('embed_avg', embed.clone())
# 是否使码本可学习
self.learnable_codebook = learnable_codebook
if learnable_codebook:
# 使码本嵌入可学习
self.embed = nn.Parameter(embed)
else:
# 注册码本嵌入
self.register_buffer('embed', embed)
# affine related params(仿射参数相关)
# 是否使用仿射参数
self.affine_param = affine_param
# 是否在分布式环境中同步仿射参数
self.sync_affine_param = sync_affine_param
if not affine_param:
# 如果不使用仿射参数,则返回
return
# 仿射参数的批次衰减因子
self.affine_param_batch_decay = affine_param_batch_decay
# 仿射参数的码本衰减因子
self.affine_param_codebook_decay = affine_param_codebook_decay
# 批次均值
self.register_buffer('batch_mean', None)
# 批次方差
self.register_buffer('batch_variance', None)
# 码本均值是否需要初始化
self.register_buffer('codebook_mean_needs_init', torch.Tensor([True]))
# 码本均值
self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim))
# 码本方差是否需要初始化
self.register_buffer('codebook_variance_needs_init', torch.Tensor([True]))
# 码本方差
self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))
@torch.jit.ignore
def init_embed_(self, data, mask = None):
"""
初始化码本嵌入。
参数:
data (Tensor): 输入数据。
mask (Optional[Tensor], 可选): 数据掩码。默认值为 None。
"""
if self.initted:
return
if exists(mask):
c = data.shape[0]
# 根据掩码重塑数据
data = rearrange(data[mask], '(c n) d -> c n d', c = c)
# 使用 K-Means 进行聚类
embed, cluster_size = kmeans(
data,
self.codebook_size,
self.kmeans_iters,
sample_fn = self.sample_fn,
all_reduce_fn = self.kmeans_all_reduce_fn
)
# 计算嵌入的总和
embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
# 复制嵌入到嵌入缓冲区
self.embed.data.copy_(embed)
# 复制嵌入的总和到嵌入平均值缓冲区
self.embed_avg.data.copy_(embed_sum)
# 复制聚类大小到聚类大小缓冲区
self.cluster_size.data.copy_(cluster_size)
# 标记为已初始化
self.initted.data.copy_(torch.Tensor([True]))
@torch.jit.ignore
def update_with_decay(self, buffer_name, new_value, decay):
"""
使用衰减因子更新缓冲区中的值。
参数:
buffer_name (str): 缓冲区的名称。
new_value (Tensor): 新的值。
decay (float): 衰减因子。
"""
old_value = getattr(self, buffer_name)
# 获取是否需要初始化
needs_init = getattr(self, buffer_name + "_needs_init", False)
if needs_init:
# 如果需要初始化,则标记为已初始化
self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))
if not exists(old_value) or needs_init:
# 如果旧值不存在或需要初始化,则注册新的值
self.register_buffer(buffer_name, new_value.detach())
return
# 计算新的值,使用衰减因子
value = old_value * decay + new_value.detach() * (1 - decay)
# 更新缓冲区中的值
self.register_buffer(buffer_name, value)
@torch.jit.ignore
def update_affine(self, data, embed, mask = None):
"""
更新仿射参数,包括批次均值和方差以及码本均值和方差。
参数:
data (Tensor): 输入数据。
embed (Tensor): 嵌入数据。
mask (Optional[Tensor], 可选): 数据掩码。默认值为 None。
"""
# 确保仿射参数已启用
assert self.affine_param
# 定义方差函数
var_fn = partial(torch.var, unbiased = False)
# calculate codebook mean and variance
# 计算码本均值和方差
# 重塑嵌入数据的形状
embed = rearrange(embed, 'h ... d -> h (...) d')
if self.training:
# 在训练模式下,使用 EMA 更新码本均值和方差
self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)
# prepare batch data, which depends on whether it has masking
# 准备批次数据,根据是否使用掩码
# 重塑输入数据的形状
data = rearrange(data, 'h ... d -> h (...) d')
if exists(mask):
c = data.shape[0]
# 如果存在掩码,则应用掩码
data = rearrange(data[mask], '(c n) d -> c n d', c = c)
# calculate batch mean and variance
# 计算批次均值和方差
if not self.sync_affine_param:
# 如果不同步仿射参数,则直接使用 EMA 更新批次均值和方差
self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay)
self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay)
return
# 获取向量数量、设备和数据类型
num_vectors, device, dtype = data.shape[-2], data.device, data.dtype
# number of vectors, for denominator
# 计算分布式均值
# 创建向量数量的张量
num_vectors = torch.tensor([num_vectors], device = device, dtype = dtype)
# 在分布式环境中进行归约
distributed.all_reduce(num_vectors)
# calculate distributed mean
# 计算批次数据的总和
batch_sum = reduce(data, 'h n d -> h 1 d', 'sum')
# 在分布式环境中进行归约
distributed.all_reduce(batch_sum)
# 计算批次均值
batch_mean = batch_sum / num_vectors
# 更新批次均值
self.update_with_decay('batch_mean', batch_mean, self.affine_param_batch_decay)
# calculate distributed variance
# 计算分布式方差
# 计算方差的分子部分
variance_numer = reduce((data - batch_mean) ** 2, 'h n d -> h 1 d', 'sum')
# 在分布式环境中进行归约
distributed.all_reduce(variance_numer)
# 计算批次方差
batch_variance = variance_numer / num_vectors
# 更新批次方差
self.update_with_decay('batch_variance', batch_variance, self.affine_param_batch_decay)
def replace(self, batch_samples, batch_mask):
"""
替换码本中的样本。
参数:
batch_samples (Tensor): 批次样本。
batch_mask (Tensor): 批次掩码。
"""
for ind, (samples, mask) in enumerate(zip(batch_samples, batch_mask)):
# 采样替换样本
sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
# 重塑采样后的样本
sampled = rearrange(sampled, '1 ... -> ...')
# 替换嵌入数据
self.embed.data[ind][mask] = sampled
# 重置聚类大小
self.cluster_size.data[ind][mask] = self.reset_cluster_size
# 更新嵌入平均值
self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size