1- using System ;
21using System . Collections . Generic ;
32using System . Linq ;
43using Unity . MLAgents . Inference . Utils ;
@@ -55,62 +54,26 @@ internal class DiscreteActionOutputApplier : TensorApplier.IApplier
5554 {
5655 readonly int [ ] m_ActionSize ;
5756 readonly Multinomial m_Multinomial ;
58- readonly ITensorAllocator m_Allocator ;
5957 readonly ActionSpec m_ActionSpec ;
58+ readonly int [ ] m_StartActionIndices ;
59+ readonly float [ ] m_CdfBuffer ;
60+
6061
6162 public DiscreteActionOutputApplier ( ActionSpec actionSpec , int seed , ITensorAllocator allocator )
6263 {
6364 m_ActionSize = actionSpec . BranchSizes ;
6465 m_Multinomial = new Multinomial ( seed ) ;
65- m_Allocator = allocator ;
6666 m_ActionSpec = actionSpec ;
67+ m_StartActionIndices = Utilities . CumSum ( m_ActionSize ) ;
68+
69+ // Scratch space for computing the cumulative distribution function.
70+ // In order to reuse it, make it the size of the largest branch.
71+ var largestBranch = Mathf . Max ( m_ActionSize ) ;
72+ m_CdfBuffer = new float [ largestBranch ] ;
6773 }
6874
6975 public void Apply ( TensorProxy tensorProxy , IList < int > actionIds , Dictionary < int , ActionBuffers > lastActions )
7076 {
71- //var tensorDataProbabilities = tensorProxy.Data as float[,];
72- var idActionPairList = actionIds as List < int > ?? actionIds . ToList ( ) ;
73- var batchSize = idActionPairList . Count ;
74- var actionValues = new float [ batchSize , m_ActionSize . Length ] ;
75- var startActionIndices = Utilities . CumSum ( m_ActionSize ) ;
76- for ( var actionIndex = 0 ; actionIndex < m_ActionSize . Length ; actionIndex ++ )
77- {
78- var nBranchAction = m_ActionSize [ actionIndex ] ;
79- var actionProbs = new TensorProxy ( )
80- {
81- valueType = TensorProxy . TensorType . FloatingPoint ,
82- shape = new long [ ] { batchSize , nBranchAction } ,
83- data = m_Allocator . Alloc ( new TensorShape ( batchSize , nBranchAction ) )
84- } ;
85-
86- for ( var batchIndex = 0 ; batchIndex < batchSize ; batchIndex ++ )
87- {
88- for ( var branchActionIndex = 0 ;
89- branchActionIndex < nBranchAction ;
90- branchActionIndex ++ )
91- {
92- actionProbs . data [ batchIndex , branchActionIndex ] =
93- tensorProxy . data [ batchIndex , startActionIndices [ actionIndex ] + branchActionIndex ] ;
94- }
95- }
96-
97- var outputTensor = new TensorProxy ( )
98- {
99- valueType = TensorProxy . TensorType . FloatingPoint ,
100- shape = new long [ ] { batchSize , 1 } ,
101- data = m_Allocator . Alloc ( new TensorShape ( batchSize , 1 ) )
102- } ;
103-
104- Eval ( actionProbs , outputTensor , m_Multinomial ) ;
105-
106- for ( var ii = 0 ; ii < batchSize ; ii ++ )
107- {
108- actionValues [ ii , actionIndex ] = outputTensor . data [ ii , 0 ] ;
109- }
110- actionProbs . data . Dispose ( ) ;
111- outputTensor . data . Dispose ( ) ;
112- }
113-
11477 var agentIndex = 0 ;
11578 for ( var i = 0 ; i < actionIds . Count ; i ++ )
11679 {
@@ -126,74 +89,38 @@ public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int,
12689 var discreteBuffer = actionBuffer . DiscreteActions ;
12790 for ( var j = 0 ; j < m_ActionSize . Length ; j ++ )
12891 {
129- discreteBuffer [ j ] = ( int ) actionValues [ agentIndex , j ] ;
92+ ComputeCdf ( tensorProxy , agentIndex , m_StartActionIndices [ j ] , m_ActionSize [ j ] ) ;
93+ discreteBuffer [ j ] = m_Multinomial . Sample ( m_CdfBuffer , m_ActionSize [ j ] ) ;
13094 }
13195 }
13296 agentIndex ++ ;
13397 }
13498 }
13599
136100 /// <summary>
137- /// Draw samples from a multinomial distribution based on log-probabilities specified
138- /// in tensor src. The samples will be saved in the dst tensor.
101+ /// Compute the cumulative distribution function for a given agent's action
102+ /// given the log-probabilities.
103+ /// The results are stored in m_CdfBuffer, which is the size of the largest action's number of branches.
139104 /// </summary>
140- /// <param name="src">2-D tensor with shape batch_size x num_classes</param>
141- /// <param name="dst">Allocated tensor with size batch_size x num_samples</param>
142- /// <param name="multinomial">Multinomial object used to sample values</param>
143- /// <exception cref="NotImplementedException">
144- /// Multinomial doesn't support integer tensors
145- /// </exception>
146- /// <exception cref="ArgumentException">Issue with tensor shape or type</exception>
147- /// <exception cref="ArgumentNullException">
148- /// At least one of the tensors is not allocated
149- /// </exception>
150- public static void Eval ( TensorProxy src , TensorProxy dst , Multinomial multinomial )
105+ /// <param name="logProbs"></param>
106+ /// <param name="batch">Index of the agent being considered</param>
107+ /// <param name="channelOffset">Offset into the tensor's channel.</param>
108+ /// <param name="branchSize"></param>
109+ internal void ComputeCdf ( TensorProxy logProbs , int batch , int channelOffset , int branchSize )
151110 {
152- if ( src . DataType != typeof ( float ) )
111+ // Find the class maximum
112+ var maxProb = float . NegativeInfinity ;
113+ for ( var cls = 0 ; cls < branchSize ; ++ cls )
153114 {
154- throw new NotImplementedException ( "Only float tensors are currently supported" ) ;
115+ maxProb = Mathf . Max ( logProbs . data [ batch , cls + channelOffset ] , maxProb ) ;
155116 }
156117
157- if ( src . valueType != dst . valueType )
118+ // Sum the log probabilities and compute CDF
119+ var sumProb = 0.0f ;
120+ for ( var cls = 0 ; cls < branchSize ; ++ cls )
158121 {
159- throw new ArgumentException (
160- "Source and destination tensors have different types!" ) ;
161- }
162-
163- if ( src . data == null || dst . data == null )
164- {
165- throw new ArgumentNullException ( ) ;
166- }
167-
168- if ( src . data . batch != dst . data . batch )
169- {
170- throw new ArgumentException ( "Batch size for input and output data is different!" ) ;
171- }
172-
173- var cdf = new float [ src . data . channels ] ;
174-
175- for ( var batch = 0 ; batch < src . data . batch ; ++ batch )
176- {
177- // Find the class maximum
178- var maxProb = float . NegativeInfinity ;
179- for ( var cls = 0 ; cls < src . data . channels ; ++ cls )
180- {
181- maxProb = Mathf . Max ( src . data [ batch , cls ] , maxProb ) ;
182- }
183-
184- // Sum the log probabilities and compute CDF
185- var sumProb = 0.0f ;
186- for ( var cls = 0 ; cls < src . data . channels ; ++ cls )
187- {
188- sumProb += Mathf . Exp ( src . data [ batch , cls ] - maxProb ) ;
189- cdf [ cls ] = sumProb ;
190- }
191-
192- // Generate the samples
193- for ( var sample = 0 ; sample < dst . data . channels ; ++ sample )
194- {
195- dst . data [ batch , sample ] = multinomial . Sample ( cdf ) ;
196- }
122+ sumProb += Mathf . Exp ( logProbs . data [ batch , cls + channelOffset ] - maxProb ) ;
123+ m_CdfBuffer [ cls ] = sumProb ;
197124 }
198125 }
199126 }
0 commit comments