diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs
index 9230b50dc..ac9cbc60d 100644
--- a/src/TensorFlowNET.Core/APIs/tf.image.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.image.cs
@@ -14,6 +14,10 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/
+using System;
+using Tensorflow.Contexts;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -162,17 +166,108 @@ public Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] p
public Tensor sobel_edges(Tensor image)
=> image_ops_impl.sobel_edges(image);
- public Tensor decode_jpeg(Tensor contents,
- int channels = 0,
- int ratio = 1,
- bool fancy_upscaling = true,
- bool try_recover_truncated = false,
- int acceptable_fraction = 1,
- string dct_method = "",
- string name = null)
- => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
- fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
- acceptable_fraction: acceptable_fraction, dct_method: dct_method);
+ /// <summary>
+ /// Adjust contrast of RGB or grayscale images.
+ /// </summary>
+ /// <param name="images">Images to adjust. At least 3-D.</param>
+ /// <param name="contrast_factor">A float multiplier for adjusting contrast.</param>
+ /// <returns>The contrast-adjusted image or images.</returns>
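+ /// <example>
+ /// A minimal usage sketch (values are illustrative):
+ /// <code>
+ /// var image = tf.reshape(tf.constant(new float[] { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f }), (3, 3, 1));
+ /// // Each value moves away from the mean (4.0) by a factor of 2: 0 -> -4, 8 -> 12.
+ /// var adjusted = tf.image.adjust_contrast(image, 2.0f);
+ /// </code>
+ /// </example>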
+ public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null)
+ => gen_image_ops.adjust_contrastv2(images, contrast_factor, name);
+
+ /// <summary>
+ /// Adjust hue of RGB images.
+ /// </summary>
+ /// <param name="images">RGB image or images. The size of the last dimension must be 3.</param>
+ /// <param name="delta">float. How much to add to the hue channel.</param>
+ /// <param name="name">A name for this operation (optional).</param>
+ /// <returns>Adjusted image(s), same shape and DType as `images`.</returns>
+ /// <exception cref="ValueError">If `delta` is not in the interval `[-1, 1]`.</exception>
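+ /// <example>
+ /// A minimal usage sketch (values are illustrative):
+ /// <code>
+ /// var rgb = tf.reshape(tf.constant(new float[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f }), (1, 2, 3));
+ /// // Adds 0.2 to the hue channel after an internal RGB -> HSV conversion (hue is in [0, 1]).
+ /// var shifted = tf.image.adjust_hue(rgb, 0.2f);
+ /// </code>
+ /// </example>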
+ public Tensor adjust_hue(Tensor images, float delta, string name = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ if (delta < -1f || delta > 1f)
+ throw new ValueError("delta must be in the interval [-1, 1]");
+ }
+ return gen_image_ops.adjust_hue(images, delta, name: name);
+ }
+
+ /// <summary>
+ /// Adjust saturation of RGB images.
+ /// </summary>
+ /// <param name="image">RGB image or images. The size of the last dimension must be 3.</param>
+ /// <param name="saturation_factor">float. Factor to multiply the saturation by.</param>
+ /// <param name="name">A name for this operation (optional).</param>
+ /// <returns>Adjusted image(s), same shape and DType as `image`.</returns>
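+ /// <example>
+ /// A minimal usage sketch (values are illustrative):
+ /// <code>
+ /// var rgb = tf.reshape(tf.constant(new float[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f }), (1, 2, 3));
+ /// // Halves the saturation channel in HSV space, then converts back to RGB.
+ /// var desaturated = tf.image.adjust_saturation(rgb, 0.5f);
+ /// </code>
+ /// </example>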
+ public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null)
+ => gen_image_ops.adjust_saturation(image, saturation_factor, name);
+
+ /// <summary>
+ /// Greedily selects a subset of bounding boxes in descending order of score.
+ /// </summary>
+ /// <param name="boxes">
+ /// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q`
+ /// is 1 then the same boxes are used for all classes; otherwise, if `q` is equal
+ /// to the number of classes, class-specific boxes are used.
+ /// </param>
+ /// <param name="scores">
+ /// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]`
+ /// representing a single score corresponding to each box (each row of boxes).
+ /// </param>
+ /// <param name="max_output_size_per_class">
+ /// A scalar integer `Tensor` representing the maximum number of boxes to be
+ /// selected by non-max suppression per class.
+ /// </param>
+ /// <param name="max_total_size">
+ /// An int32 scalar representing the maximum number of boxes retained over all
+ /// classes. Note that setting this value to a large number may result in an
+ /// OOM error depending on the system workload.
+ /// </param>
+ /// <param name="iou_threshold">
+ /// A float representing the threshold for deciding whether boxes
+ /// overlap too much with respect to IOU.
+ /// </param>
+ /// <param name="score_threshold">
+ /// A float representing the threshold for deciding when to
+ /// remove boxes based on score.
+ /// </param>
+ /// <param name="pad_per_class">
+ /// If false, the output nmsed boxes, scores and classes are padded/clipped to
+ /// `max_total_size`. If true, they are padded to be of length
+ /// `max_output_size_per_class` * `num_classes`, unless that exceeds
+ /// `max_total_size`, in which case they are clipped to `max_total_size`.
+ /// Defaults to false.
+ /// </param>
+ /// <param name="clip_boxes">
+ /// If true, the coordinates of the output nmsed boxes will be clipped to
+ /// [0, 1]. If false, the box coordinates are output as-is. Defaults to true.
+ /// </param>
+ /// <returns>
+ /// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes.
+ /// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes.
+ /// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for the boxes.
+ /// 'valid_detections': A [batch_size] int32 tensor indicating the number of
+ /// valid detections per batch item. Only the top valid_detections[i] entries
+ /// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid; the remaining
+ /// entries are zero padding.
+ /// </returns>
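+ /// <example>
+ /// A minimal usage sketch, mirroring the unit test added in this PR (values are illustrative):
+ /// <code>
+ /// var boxes = tf.reshape(tf.constant(new float[,] { { 200, 100, 150, 100 }, { 220, 120, 150, 100 } }), (1, 2, 1, 4));
+ /// var scores = tf.reshape(tf.constant(new float[,] { { 0.2f, 0.7f, 0.1f }, { 0.1f, 0.8f, 0.1f } }), (1, 2, 3));
+ /// var (nmsedBoxes, nmsedScores, nmsedClasses, validDetections) =
+ ///     tf.image.combined_non_max_suppression(boxes, scores, 10, 10, 0.5f, 0.2f, clip_boxes: false);
+ /// </code>
+ /// </example>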
+ public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(
+ Tensor boxes,
+ Tensor scores,
+ int max_output_size_per_class,
+ int max_total_size,
+ float iou_threshold,
+ float score_threshold,
+ bool pad_per_class = false,
+ bool clip_boxes = true)
+ {
+ var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold");
+ var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold");
+ var max_total_size_t = ops.convert_to_tensor(max_total_size);
+ var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class);
+ return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t,
+ iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes);
+ }
/// <summary>
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change.
/// </summary>
@@ -187,7 +282,19 @@ public Tensor decode_jpeg(Tensor contents,
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns>
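+ /// <example>
+ /// A minimal usage sketch, mirroring the unit test added in this PR:
+ /// <code>
+ /// var image = tf.random.uniform((1, 256, 256, 3));
+ /// var boxes = tf.random.uniform((5, 4));
+ /// var box_ind = tf.random.uniform((5), minval: 0, maxval: 1, dtype: TF_DataType.TF_INT32);
+ /// var crops = tf.image.crop_and_resize(image, boxes, box_ind, tf.constant(new int[] { 24, 24 }));
+ /// // crops.shape == (5, 24, 24, 3)
+ /// </code>
+ /// </example>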
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
- image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
+ gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
+
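+ /// <summary>
+ /// Decode a JPEG-encoded image to a uint8 tensor.
+ /// </summary>
+ /// <example>
+ /// A minimal usage sketch (the file path is illustrative):
+ /// <code>
+ /// var contents = tf.io.read_file("sample.jpg");
+ /// var img = tf.image.decode_jpeg(contents, channels: 3);   // force 3 RGB channels
+ /// </code>
+ /// </example>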
+ public Tensor decode_jpeg(Tensor contents,
+ int channels = 0,
+ int ratio = 1,
+ bool fancy_upscaling = true,
+ bool try_recover_truncated = false,
+ int acceptable_fraction = 1,
+ string dct_method = "",
+ string name = null)
+ => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
+ fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
+ acceptable_fraction: acceptable_fraction, dct_method: dct_method, name: name);
public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
bool uniform_noise = true, string name = null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs
index 9240b5905..cbe661ae5 100644
--- a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs
@@ -16,18 +16,312 @@ limitations under the License.
using System;
using System.Linq;
+using System.Collections.Generic;
+using Tensorflow.Eager;
using static Tensorflow.Binding;
+using Tensorflow.Exceptions;
+using Tensorflow.Contexts;
namespace Tensorflow
{
public class gen_image_ops
{
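+ // The wrappers below follow the three-stage dispatch pattern of TensorFlow's
+ // generated Python op bindings: in eager mode, try the TFE fast path first,
+ // then fall back to an explicit eager execution (*_eager_fallback); in graph
+ // mode, build the op via OpDefLib and record a gradient when required.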
+ public static Tensor adjust_contrastv2(Tensor images, Tensor contrast_factor, string name = null)
+ {
+ var _ctx = tf.Context;
+ if (_ctx.executing_eagerly())
+ {
+ try
+ {
+ var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustContrastv2", name) {
+ args = new object[] { images, contrast_factor }, attrs = new Dictionary<string, object>() { } });
+ return _fast_path_result[0];
+ }
+ catch (NotOkStatusException)
+ {
+ throw;
+ }
+ catch (Exception)
+ {
+ }
+ try
+ {
+ return adjust_contrastv2_eager_fallback(images, contrast_factor, name: name, ctx: _ctx);
+ }
+ catch (Exception)
+ {
+ }
+ }
+ Dictionary<string, object> keywords = new();
+ keywords["images"] = images;
+ keywords["contrast_factor"] = contrast_factor;
+ var _op = tf.OpDefLib._apply_op_helper("AdjustContrastv2", name, keywords);
+ var _result = _op.outputs;
+ if (_execute.must_record_gradient())
+ {
+ object[] _attrs = new object[] { "T", _op._get_attr_type("T") };
+ _execute.record_gradient("AdjustContrastv2", _op.inputs, _attrs, _result);
+ }
+ return _result[0];
+ }
+ public static Tensor adjust_contrastv2(Tensor image, float contrast_factor, string name = null)
+ {
+ return adjust_contrastv2(image, tf.convert_to_tensor(contrast_factor), name: name);
+ }
+
+ public static Tensor adjust_contrastv2_eager_fallback(Tensor images, Tensor contrast_factor, string name, Context ctx)
+ {
+ Tensor[] _inputs_flat = new Tensor[] { images, contrast_factor};
+ object[] _attrs = new object[] { "T", images.dtype };
+ var _result = _execute.execute("AdjustContrastv2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
+ if (_execute.must_record_gradient())
+ {
+ _execute.record_gradient("AdjustContrastv2", _inputs_flat, _attrs, _result);
+ }
+ return _result[0];
+ }
+
+ public static Tensor adjust_hue(Tensor images, Tensor delta, string name = null)
+ {
+ var _ctx = tf.Context;
+ if (_ctx.executing_eagerly())
+ {
+ try
+ {
+ var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustHue", name) {
+ args = new object[] { images, delta }, attrs = new Dictionary<string, object>() { } });
+ return _fast_path_result[0];
+ }
+ catch (NotOkStatusException)
+ {
+ throw;
+ }
+ catch (Exception)
+ {
+ }
+ try
+ {
+ return adjust_hue_eager_fallback(images, delta, name: name, ctx: _ctx);
+ }
+ catch (Exception)
+ {
+ }
+ }
+ Dictionary<string, object> keywords = new();
+ keywords["images"] = images;
+ keywords["delta"] = delta;
+ var _op = tf.OpDefLib._apply_op_helper("AdjustHue", name, keywords);
+ var _result = _op.outputs;
+ if (_execute.must_record_gradient())
+ {
+ object[] _attrs = new object[] { "T", _op._get_attr_type("T") };
+ _execute.record_gradient("AdjustHue", _op.inputs, _attrs, _result);
+ }
+ return _result[0];
+ }
+
+ public static Tensor adjust_hue(Tensor images, float delta, string name = null)
+ => adjust_hue(images, ops.convert_to_tensor(delta), name: name);
+
+ public static Tensor adjust_hue_eager_fallback(Tensor images, Tensor delta, string name, Context ctx)
+ {
+ Tensor[] _inputs_flat = new Tensor[] { images, delta};
+ object[] _attrs = new object[] { "T", images.dtype };
+ var _result = _execute.execute("AdjustHue", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
+ if (_execute.must_record_gradient())
+ {
+ _execute.record_gradient("AdjustHue", _inputs_flat, _attrs, _result);
+ }
+ return _result[0];
+ }
+
+ public static Tensor adjust_saturation(Tensor images, Tensor scale, string name = null)
+ {
+ var _ctx = tf.Context;
+ if (_ctx.executing_eagerly())
+ {
+ try
+ {
+ var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustSaturation", name)
+ {
+ args = new object[] { images, scale },
+ attrs = new Dictionary<string, object>() { }
+ });
+ return _fast_path_result[0];
+ }
+ catch (NotOkStatusException)
+ {
+ throw;
+ }
+ catch (Exception)
+ {
+ }
+ try
+ {
+ return adjust_saturation_eager_fallback(images, scale, name: name, ctx: _ctx);
+ }
+ catch (Exception)
+ {
+ }
+ }
+ Dictionary<string, object> keywords = new();
+ keywords["images"] = images;
+ keywords["scale"] = scale;
+ var _op = tf.OpDefLib._apply_op_helper("AdjustSaturation", name, keywords);
+ var _result = _op.outputs;
+ if (_execute.must_record_gradient())
+ {
+ object[] _attrs = new object[] { "T", _op._get_attr_type("T") };
+ _execute.record_gradient("AdjustSaturation", _op.inputs, _attrs, _result);
+ }
+ return _result[0];
+ }
+
+ public static Tensor adjust_saturation(Tensor images, float scale, string name = null)
+ => adjust_saturation(images, ops.convert_to_tensor(scale), name: name);
+
+ public static Tensor adjust_saturation_eager_fallback(Tensor images, Tensor scale, string name, Context ctx)
+ {
+ Tensor[] _inputs_flat = new Tensor[] { images, scale };
+ object[] _attrs = new object[] { "T", images.dtype };
+ var _result = _execute.execute("AdjustSaturation", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
+ if (_execute.must_record_gradient())
+ {
+ _execute.record_gradient("AdjustSaturation", _inputs_flat, _attrs, _result);
+ }
+ return _result[0];
+ }
+
public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, Tensor max_total_size,
- Tensor iou_threshold, Tensor score_threshold, bool pad_per_class, bool clip_boxes)
+ Tensor iou_threshold, Tensor score_threshold, bool pad_per_class = false, bool clip_boxes = true, string name = null)
{
- throw new NotImplementedException("combined_non_max_suppression");
+ var _ctx = tf.Context;
+ if (_ctx.executing_eagerly())
+ {
+ try
+ {
+ var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CombinedNonMaxSuppression", name){
+ args = new object[] {
+ boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold,
+ "pad_per_class", pad_per_class, "clip_boxes", clip_boxes},
+ attrs = new Dictionary<string, object>() { }});
+ return (_fast_path_result[0], _fast_path_result[1], _fast_path_result[2], _fast_path_result[3]);
+ }
+ catch (NotOkStatusException)
+ {
+ throw;
+ }
+ catch (Exception)
+ {
+ }
+ try
+ {
+ return combined_non_max_suppression_eager_fallback(
+ boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
+ score_threshold, pad_per_class, clip_boxes, name, ctx: _ctx);
+ }
+ catch (Exception)
+ {
+ }
+ }
+ Dictionary<string, object> keywords = new();
+ keywords["boxes"] = boxes;
+ keywords["scores"] = scores;
+ keywords["max_output_size_per_class"] = max_output_size_per_class;
+ keywords["max_total_size"] = max_total_size;
+ keywords["iou_threshold"] = iou_threshold;
+ keywords["score_threshold"] = score_threshold;
+ keywords["pad_per_class"] = pad_per_class;
+ keywords["clip_boxes"] = clip_boxes;
+
+ var _op = tf.OpDefLib._apply_op_helper("CombinedNonMaxSuppression", name, keywords);
+ var _result = _op.outputs;
+ if (_execute.must_record_gradient())
+ {
+ object[] _attrs = new object[] { "pad_per_class", _op._get_attr_type("pad_per_class") ,"clip_boxes", _op._get_attr_type("clip_boxes")};
+ _execute.record_gradient("CombinedNonMaxSuppression", _op.inputs, _attrs, _result);
+ }
+ return (_result[0], _result[1], _result[2], _result[3]);
}
+ public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression_eager_fallback(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, Tensor max_total_size,
+ Tensor iou_threshold, Tensor score_threshold, bool pad_per_class, bool clip_boxes, string name, Context ctx)
+ {
+ Tensor[] _inputs_flat = new Tensor[] { boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold };
+ object[] _attrs = new object[] { "pad_per_class", pad_per_class, "clip_boxes", clip_boxes };
+ var _result = _execute.execute("CombinedNonMaxSuppression", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
+ if (_execute.must_record_gradient())
+ {
+ _execute.record_gradient("CombinedNonMaxSuppression", _inputs_flat, _attrs, _result);
+ }
+ return (_result[0], _result[1], _result[2], _result[3]);
+ }
+
+ public static Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null)
+ {
+ var _ctx = tf.Context;
+ if (_ctx.executing_eagerly())
+ {
+ try
+ {
+ var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CropAndResize", name) {
+ args = new object[] {
+ image, boxes, box_ind, crop_size, "method", method, "extrapolation_value", extrapolation_value }, attrs = new Dictionary() { } });
+ return _fast_path_result[0];
+ }
+ catch (NotOkStatusException)
+ {
+ throw;
+ }
+ catch (Exception)
+ {
+ }
+ try
+ {
+ return crop_and_resize_eager_fallback(
+ image, boxes, box_ind, crop_size, method: method, extrapolation_value: extrapolation_value, name: name, ctx: _ctx);
+ }
+ catch (Exception)
+ {
+ }
+ }
+ Dictionary<string, object> keywords = new();
+ keywords["image"] = image;
+ keywords["boxes"] = boxes;
+ keywords["box_ind"] = box_ind;
+ keywords["crop_size"] = crop_size;
+ keywords["method"] = method;
+ keywords["extrapolation_value"] = extrapolation_value;
+ var _op = tf.OpDefLib._apply_op_helper("CropAndResize", name, keywords);
+ var _result = _op.outputs;
+ if (_execute.must_record_gradient())
+ {
+ object[] _attrs = new object[] { "T", _op._get_attr_type("T") ,"method", _op._get_attr_type("method") ,
+ "extrapolation_value", _op.get_attr("extrapolation_value")};
+ _execute.record_gradient("CropAndResize", _op.inputs, _attrs, _result);
+ }
+ return _result[0];
+ }
+
+ public static Tensor crop_and_resize_eager_fallback(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method, float extrapolation_value, string name, Context ctx)
+ {
+ if (method is null)
+ method = "bilinear";
+
+ // `method` and `extrapolation_value` are attributes of CropAndResize, not inputs,
+ // so they are passed through _attrs rather than converted to tensors.
+ Tensor[] _inputs_flat = new Tensor[] { image, boxes, box_ind, crop_size };
+ object[] _attrs = new object[] { "T", image.dtype, "method", method, "extrapolation_value", extrapolation_value };
+ var _result = _execute.execute("CropAndResize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
+ if (_execute.must_record_gradient())
+ {
+ _execute.record_gradient("CropAndResize", _inputs_flat, _attrs, _result);
+ }
+ return _result[0];
+ }
+
public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null)
{
if (dtype == image.dtype)
diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
index c42445cf1..d671b6096 100644
--- a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
@@ -3,6 +3,7 @@
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
+using System;
namespace TensorFlowNET.UnitTest
{
@@ -22,13 +23,86 @@ public void Initialize()
contents = tf.io.read_file(imgPath);
}
+ [TestMethod]
+ public void adjust_contrast()
+ {
+ var input = np.array(0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f);
+ var image = tf.reshape(input, new int[] { 3, 3, 1 });
+
+ var init = tf.global_variables_initializer();
+ var sess = tf.Session();
+ sess.run(init);
+ var adjust_contrast = tf.image.adjust_contrast(image, 2.0f);
+ var result = sess.run(adjust_contrast);
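+ // adjust_contrast maps each value x to (x - mean) * contrast_factor + mean; the mean here is 4.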
+ var res = np.array(-4f, -2f, 0f, 2f, 4f, 6f, 8f, 10f, 12f).reshape((3,3,1));
+ Assert.AreEqual(result.numpy(), res);
+ }
+
+ [Ignore]
+ [TestMethod]
+ public void adjust_hue()
+ {
+ var image = tf.constant(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 });
+ image = tf.reshape(image, new int[] { 3, 2, 3 });
+ var adjusted_image = tf.image.adjust_hue(image, 0.2f);
+ var res = tf.constant(new int[] { 2, 1, 3, 4, 5, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 18 });
+ res = tf.reshape(res, (3, 2, 3));
+ Assert.AreEqual(adjusted_image, res);
+ }
+
+ [TestMethod]
+ public void combined_non_max_suppression()
+ {
+ var boxesX = tf.constant(new float[,] { { 200, 100, 150, 100 }, { 220, 120, 150, 100 }, { 190, 110, 150, 100 }, { 210, 112, 150, 100 } });
+ var boxes1 = tf.reshape(boxesX, (1, 4, 1, 4));
+ var scoresX = tf.constant(new float[,] { { 0.2f, 0.7f, 0.1f }, { 0.1f, 0.8f, 0.1f }, { 0.3f, 0.6f, 0.1f }, { 0.05f, 0.9f, 0.05f } });
+ var scores1 = tf.reshape(scoresX, (1, 4, 3));
+
+ var init = tf.global_variables_initializer();
+ var sess = tf.Session();
+ sess.run(init);
+
+ var (boxes, scores, classes, valid_detections) = tf.image.combined_non_max_suppression(boxes1, scores1, 10, 10, 0.5f, 0.2f, clip_boxes: false);
+ var result = sess.run((boxes, scores, classes, valid_detections));
+
+ var boxes_gt = tf.constant(new float[,] { { 210f, 112f, 150f, 100f }, { 200f, 100f, 150f, 100f }, { 190f, 110f, 150f, 100f },
+ { 0f, 0f, 0f, 0f},{ 0f, 0f, 0f, 0f},{ 0f, 0f, 0f, 0f},{ 0f, 0f, 0f , 0f},{ 0f, 0f, 0f, 0f},{ 0f , 0f, 0f, 0f},{ 0f, 0f, 0f, 0f} });
+ boxes_gt = tf.reshape(boxes_gt, (1, 10, 4));
+ Assert.AreEqual(result.Item1.numpy(), boxes_gt.numpy());
+ var scores_gt = tf.constant(new float[,] { { 0.9f, 0.7f, 0.3f, 0f, 0f, 0f, 0f, 0f, 0f, 0f } });
+ scores_gt = tf.reshape(scores_gt, (1, 10));
+ Assert.AreEqual(result.Item2.numpy(), scores_gt.numpy());
+ var classes_gt = tf.constant(new float[,] { { 1f, 1f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f } });
+ classes_gt = tf.reshape(classes_gt, (1, 10));
+ Assert.AreEqual(result.Item3.numpy(), classes_gt.numpy());
+ var valid_detections_gt = tf.constant(new int[,] { { 3 } });
+ valid_detections_gt = tf.reshape(valid_detections_gt, (1));
+ Assert.AreEqual(result.Item4.numpy(), valid_detections_gt.numpy());
+ }
+
+ [TestMethod]
+ public void crop_and_resize()
+ {
+ int BATCH_SIZE = 1;
+ int NUM_BOXES = 5;
+ int IMAGE_HEIGHT = 256;
+ int IMAGE_WIDTH = 256;
+ int CHANNELS = 3;
+ var crop_size = tf.constant(new int[] { 24, 24 });
+ var image = tf.random.uniform((BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS));
+ var boxes = tf.random.uniform((NUM_BOXES, 4));
+ var box_ind = tf.random.uniform((NUM_BOXES), minval: 0, maxval: BATCH_SIZE, dtype: TF_DataType.TF_INT32);
+ var output = tf.image.crop_and_resize(image, boxes, box_ind, crop_size);
+ Assert.AreEqual((5,24,24,3), output.shape);
+ }
+
[TestMethod]
public void decode_image()
{
var img = tf.image.decode_image(contents);
Assert.AreEqual(img.name, "decode_image/DecodeImage:0");
}
-
+
[TestMethod]
public void resize_image()
{