import torch
import torch.nn.functional as F
import numpy as np


def log_1_min_a(a):
    # Numerically stable log(1 - exp(a)); the 1e-40 term guards against log(0).
    return torch.log(1 - a.exp() + 1e-40)

def log_add_exp(a, b):
    # Numerically stable log(exp(a) + exp(b)) via the standard max trick.
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))

def extract(a, t, x_shape):
    # Gather the per-timestep coefficients a[t] and reshape them so they
    # broadcast over a tensor of shape x_shape.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)

def index_to_log_onehot(x, num_classes):
    assert x.max().item() < num_classes, \
        f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    # Move the class axis to dim 1: (B, ..., num_classes) -> (B, num_classes, ...).
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x
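
# Shape note (illustrative): for integer tokens x of shape (B, L),
# index_to_log_onehot returns log one-hot probabilities of shape
# (B, num_classes, L); log_onehot_to_index below inverts the encoding
# by taking the argmax over the class axis.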

def log_onehot_to_index(log_x):
    return log_x.argmax(1)

def alpha_schedule(time_step, N=100, att_1=0.99999, att_T=0.000009, ctt_1=0.000009, ctt_T=0.99999):
    # Linear schedules for the cumulative keep probability att (a-bar_t) and the
    # cumulative mask probability ctt (c-bar_t); the per-step at/ct are recovered
    # as ratios of consecutive cumulative values, and bt is the uniform transition
    # probability split over the N non-mask classes.
    att = np.arange(0, time_step) / (time_step - 1) * (att_T - att_1) + att_1
    att = np.concatenate(([1], att))
    at = att[1:] / att[:-1]
    ctt = np.arange(0, time_step) / (time_step - 1) * (ctt_T - ctt_1) + ctt_1
    ctt = np.concatenate(([0], ctt))
    one_minus_ctt = 1 - ctt
    one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
    ct = 1 - one_minus_ct
    bt = (1 - at - ct) / N
    att = np.concatenate((att[1:], [1]))
    ctt = np.concatenate((ctt[1:], [0]))
    btt = (1 - att - ctt) / N
    return at, bt, ct, att, btt, ctt
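
# A minimal sanity check for alpha_schedule (hypothetical helper, added here
# as a sketch): every row of the forward transition matrix must sum to one,
# i.e. at + N*bt + ct == 1 per step and att + N*btt + ctt == 1 cumulatively,
# which holds by construction of bt and btt.
def _check_alpha_schedule(time_step=100, N=100):
    at, bt, ct, att, btt, ctt = alpha_schedule(time_step, N=N)
    assert np.allclose(at + N * bt + ct, 1.0)
    assert np.allclose(att + N * btt + ctt, 1.0)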


class OrigScheduler:
    def __init__(self, *, num_classes, content_seq_len, num_timesteps=100):
        self.num_timesteps = num_timesteps
        self.num_classes = num_classes
        self.content_seq_len = content_seq_len

        at, bt, ct, att, btt, ctt = alpha_schedule(self.num_timesteps, N=self.num_classes - 1)

        at = torch.tensor(at.astype('float64'))
        bt = torch.tensor(bt.astype('float64'))
        ct = torch.tensor(ct.astype('float64'))
        log_at = torch.log(at)
        log_bt = torch.log(bt)
        log_ct = torch.log(ct)
        att = torch.tensor(att.astype('float64'))
        btt = torch.tensor(btt.astype('float64'))
        ctt = torch.tensor(ctt.astype('float64'))
        log_cumprod_at = torch.log(att)
        log_cumprod_bt = torch.log(btt)
        log_cumprod_ct = torch.log(ctt)

        log_1_min_ct = log_1_min_a(log_ct)
        log_1_min_cumprod_ct = log_1_min_a(log_cumprod_ct)

        # Sanity check: ct + (1 - ct) and ct~ + (1 - ct~) must equal 1,
        # i.e. log_add_exp of each pair must be ~ log 1 = 0.
        assert log_add_exp(log_ct, log_1_min_ct).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_ct, log_1_min_cumprod_ct).abs().sum().item() < 1.e-5

        # Convert to float32 and store as plain attributes.
        self.log_at = log_at.float()
        self.log_bt = log_bt.float()
        self.log_ct = log_ct.float()
        self.log_cumprod_at = log_cumprod_at.float()
        self.log_cumprod_bt = log_cumprod_bt.float()
        self.log_cumprod_ct = log_cumprod_ct.float()
        self.log_1_min_ct = log_1_min_ct.float()
        self.log_1_min_cumprod_ct = log_1_min_cumprod_ct.float()

    def q_posterior(self, log_x_start, log_x_t, t):  # p_theta(xt_1|xt) = sum(q(xt-1|xt,x0')*p(x0'))
        # Note that log_x_t is one-hot.
        assert t.min().item() >= 0 and t.max().item() < self.num_timesteps
        batch_size = log_x_start.size()[0]
        onehot_x_t = log_onehot_to_index(log_x_t)
        # Positions currently in the [MASK] class (the last class index).
        mask = (onehot_x_t == self.num_classes - 1).unsqueeze(1)
        log_one_vector = torch.zeros(batch_size, 1, 1).type_as(log_x_t)
        log_zero_vector = torch.log(log_one_vector + 1.0e-30).expand(-1, -1, self.content_seq_len)

        log_qt = self.q_pred(log_x_t, t)  # q(xt|x0)
        # log_qt = torch.cat((log_qt[:,:-1,:], log_zero_vector), dim=1)
        log_qt = log_qt[:, :-1, :]
        log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape)  # ct~
        ct_cumprod_vector = log_cumprod_ct.expand(-1, self.num_classes - 1, -1)
        # ct_cumprod_vector = torch.cat((ct_cumprod_vector, log_one_vector), dim=1)
        log_qt = (~mask) * log_qt + mask * ct_cumprod_vector

        log_qt_one_timestep = self.q_pred_one_timestep(log_x_t, t)  # q(xt|xt_1)
        log_qt_one_timestep = torch.cat((log_qt_one_timestep[:, :-1, :], log_zero_vector), dim=1)
        log_ct = extract(self.log_ct, t, log_x_start.shape)  # ct
        ct_vector = log_ct.expand(-1, self.num_classes - 1, -1)
        ct_vector = torch.cat((ct_vector, log_one_vector), dim=1)
        log_qt_one_timestep = (~mask) * log_qt_one_timestep + mask * ct_vector

        # log_x_start = torch.cat((log_x_start, log_zero_vector), dim=1)
        # q = log_x_start - log_qt
        q = log_x_start[:, :-1, :] - log_qt
        q = torch.cat((q, log_zero_vector), dim=1)
        q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
        q = q - q_log_sum_exp
        log_EV_xtmin_given_xt_given_xstart = self.q_pred(q, t - 1) + log_qt_one_timestep + q_log_sum_exp
        return torch.clamp(log_EV_xtmin_given_xt_given_xstart, -70, 0)

    def q_pred_one_timestep(self, log_x_t, t):  # q(xt|xt_1)
        log_at = extract(self.log_at, t, log_x_t.shape)              # at
        log_bt = extract(self.log_bt, t, log_x_t.shape)              # bt
        log_ct = extract(self.log_ct, t, log_x_t.shape)              # ct
        log_1_min_ct = extract(self.log_1_min_ct, t, log_x_t.shape)  # 1-ct

        # Non-mask classes: at*x_t + bt; mask class: (1-ct)*x_t + ct, in log space.
        log_probs = torch.cat(
            [
                log_add_exp(log_x_t[:, :-1, :] + log_at, log_bt),
                log_add_exp(log_x_t[:, -1:, :] + log_1_min_ct, log_ct)
            ],
            dim=1
        )

        return log_probs

    def q_pred(self, log_x_start, t):  # q(xt|x0)
        # log_x_start can be one-hot or not.
        # Wrap t = -1 around to index num_timesteps, where the cumulative
        # products are padded with the identity (a-bar = 1, c-bar = 0).
        t = (t + (self.num_timesteps + 1)) % (self.num_timesteps + 1)
        log_cumprod_at = extract(self.log_cumprod_at, t, log_x_start.shape)              # at~
        log_cumprod_bt = extract(self.log_cumprod_bt, t, log_x_start.shape)              # bt~
        log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape)              # ct~
        log_1_min_cumprod_ct = extract(self.log_1_min_cumprod_ct, t, log_x_start.shape)  # 1-ct~

        log_probs = torch.cat(
            [
                log_add_exp(log_x_start[:, :-1, :] + log_cumprod_at, log_cumprod_bt),
                log_add_exp(log_x_start[:, -1:, :] + log_1_min_cumprod_ct, log_cumprod_ct)
            ],
            dim=1
        )

        return log_probs
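

# A minimal usage sketch (illustrative, with assumed toy sizes: a codebook of
# 16 content classes plus one [MASK] class, sequences of length 32). It draws
# clean tokens, corrupts them with q(x_t|x_0), and evaluates the posterior.
if __name__ == "__main__":
    num_classes, seq_len, T = 17, 32, 100
    sched = OrigScheduler(num_classes=num_classes, content_seq_len=seq_len, num_timesteps=T)
    x0 = torch.randint(0, num_classes - 1, (4, seq_len))   # clean tokens, no [MASK]
    log_x0 = index_to_log_onehot(x0, num_classes)          # (4, num_classes, seq_len)
    t = torch.randint(0, T, (4,))
    log_xt = sched.q_pred(log_x0, t)                       # log q(x_t | x_0)
    xt = log_onehot_to_index(log_xt)                       # argmax decode of the corrupted state
    log_post = sched.q_posterior(log_x0, index_to_log_onehot(xt, num_classes), t)
    print(log_xt.shape, log_post.shape)                    # both (4, num_classes, seq_len)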