1- using Microsoft . VisualStudio . TestPlatform . Utilities ;
2- using Microsoft . VisualStudio . TestTools . UnitTesting ;
1+ using Microsoft . VisualStudio . TestTools . UnitTesting ;
32using System ;
4- using System . Collections . Generic ;
5- using System . Linq ;
6- using System . Text ;
7- using System . Threading . Tasks ;
8- using System . Xml . Linq ;
9- using Tensorflow . Operations ;
10- using static Tensorflow . Binding ;
11- using static Tensorflow . KerasApi ;
12- using Tensorflow . NumPy ;
13- using Microsoft . VisualBasic ;
14- using static HDF . PInvoke . H5T ;
15- using Tensorflow . Keras . UnitTest . Helpers ;
3+ using Tensorflow ;
164using Tensorflow . Keras . Optimizers ;
5+ using Tensorflow . NumPy ;
6+ using static Tensorflow . KerasApi ;
177
18- namespace Tensorflow . Keras . UnitTest
8+ namespace TensorFlowNET . Keras . UnitTest
199{
2010 [ TestClass ]
2111 public class MultiInputModelTest
2212 {
2313 [ TestMethod ]
24- public void SimpleModel ( )
14+ public void LeNetModel ( )
2515 {
2616 var inputs = keras . Input ( ( 28 , 28 , 1 ) ) ;
2717 var conv1 = keras . layers . Conv2D ( 16 , ( 3 , 3 ) , activation : "relu" , padding : "same" ) . Apply ( inputs ) ;
@@ -40,7 +30,7 @@ public void SimpleModel()
4030 var concat = keras . layers . Concatenate ( ) . Apply ( ( flat1 , flat1_2 ) ) ;
4131 var dense1 = keras . layers . Dense ( 512 , activation : "relu" ) . Apply ( concat ) ;
4232 var dense2 = keras . layers . Dense ( 128 , activation : "relu" ) . Apply ( dense1 ) ;
43- var dense3 = keras . layers . Dense ( 10 , activation : "relu" ) . Apply ( dense2 ) ;
33+ var dense3 = keras . layers . Dense ( 10 , activation : "relu" ) . Apply ( dense2 ) ;
4434 var output = keras . layers . Softmax ( - 1 ) . Apply ( dense3 ) ;
4535
4636 var model = keras . Model ( ( inputs , inputs_2 ) , output ) ;
@@ -52,7 +42,7 @@ public void SimpleModel()
5242 {
5343 TrainDir = "mnist" ,
5444 OneHot = false ,
55- ValidationSize = 59000 ,
45+ ValidationSize = 59900 ,
5646 } ) . Result ;
5747
5848 var loss = keras . losses . SparseCategoricalCrossentropy ( ) ;
@@ -64,6 +54,11 @@ public void SimpleModel()
6454
6555 var x = new NDArray [ ] { x1 , x2 } ;
6656 model . fit ( x , dataset . Train . Labels , batch_size : 8 , epochs : 3 ) ;
57+
58+ x1 = np . ones ( ( 1 , 28 , 28 , 1 ) , TF_DataType . TF_FLOAT ) ;
59+ x2 = np . zeros ( ( 1 , 28 , 28 , 1 ) , TF_DataType . TF_FLOAT ) ;
60+ var pred = model . predict ( ( x1 , x2 ) ) ;
61+ Console . WriteLine ( pred ) ;
6762 }
6863 }
6964}
0 commit comments