mirror of
https://github.com/google-ai-edge/LiteRT.git
synced 2026-01-09 06:21:22 +08:00
Add float16 support to comparison ops
PiperOrigin-RevId: 853073492
This commit is contained in:
parent
347c6741ab
commit
63806acb01
@ -16,11 +16,11 @@ limitations under the License.
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "Eigen/Core" // from @eigen_archive
|
||||
#include "tflite/core/c/common.h"
|
||||
#include "tflite/kernels/internal/compatibility.h"
|
||||
#include "tflite/kernels/internal/quantization_util.h"
|
||||
#include "tflite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tflite/kernels/internal/tensor.h"
|
||||
#include "tflite/kernels/internal/reference/string_comparisons.h"
|
||||
#include "tflite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tflite/kernels/internal/types.h"
|
||||
#include "tflite/kernels/kernel_util.h"
|
||||
@ -180,6 +180,10 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Comparison<float, reference_ops::EqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteFloat16:
|
||||
Comparison<Eigen::half, reference_ops::EqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
if (input1->quantization.type == kTfLiteNoQuantization) {
|
||||
Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output,
|
||||
@ -239,6 +243,10 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Comparison<float, reference_ops::NotEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteFloat16:
|
||||
Comparison<Eigen::half, reference_ops::NotEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
Comparison<int32_t, reference_ops::NotEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
@ -291,6 +299,10 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Comparison<float, reference_ops::GreaterFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteFloat16:
|
||||
Comparison<Eigen::half, reference_ops::GreaterFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
@ -332,6 +344,10 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteFloat16:
|
||||
Comparison<Eigen::half, reference_ops::GreaterEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
Comparison<int16_t, reference_ops::GreaterEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
@ -430,6 +446,10 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
Comparison<float, reference_ops::LessEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteFloat16:
|
||||
Comparison<Eigen::half, reference_ops::LessEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
|
||||
@ -125,6 +125,17 @@ TEST(ComparisonsTest, EqualFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, EqualFloat16) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
|
||||
BuiltinOperator_EQUAL);
|
||||
model.PopulateTensor<half>(model.input1(), {0.1, 0.9, 0.7, 0.3});
|
||||
model.PopulateTensor<half>(model.input2(), {0.1, 0.2, 0.6, 0.5});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, EqualInt) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
|
||||
BuiltinOperator_EQUAL);
|
||||
@ -220,6 +231,17 @@ TEST(ComparisonsTest, NotEqualFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, NotEqualFloat16) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
|
||||
BuiltinOperator_NOT_EQUAL);
|
||||
model.PopulateTensor<half>(model.input1(), {0.1, 0.9, 0.7, 0.3});
|
||||
model.PopulateTensor<half>(model.input2(), {0.1, 0.2, 0.6, 0.5});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, NotEqualInt) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
|
||||
BuiltinOperator_NOT_EQUAL);
|
||||
@ -293,6 +315,17 @@ TEST(ComparisonsTest, GreaterFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterFloat16) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
|
||||
BuiltinOperator_GREATER);
|
||||
model.PopulateTensor<half>(model.input1(), {0.1f, 0.9f, 0.7f, 0.3f});
|
||||
model.PopulateTensor<half>(model.input2(), {0.1f, 0.2f, 0.6f, 0.5f});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterInt) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
|
||||
BuiltinOperator_GREATER);
|
||||
@ -338,6 +371,17 @@ TEST(ComparisonsTest, GreaterEqualFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterEqualFloat16) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
|
||||
BuiltinOperator_GREATER_EQUAL);
|
||||
model.PopulateTensor<half>(model.input1(), {0.1f, 0.9f, 0.7f, 0.3f});
|
||||
model.PopulateTensor<half>(model.input2(), {0.1f, 0.2f, 0.6f, 0.5f});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterEqualInt) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
|
||||
BuiltinOperator_GREATER_EQUAL);
|
||||
@ -478,6 +522,19 @@ TEST(ComparisonsTest, LessEqualFloat) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, LessEqualFloat16) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
|
||||
BuiltinOperator_LESS_EQUAL);
|
||||
model.PopulateTensor<half>(model.input1(),
|
||||
{half(0.1f), half(0.9f), half(0.7f), half(0.3f)});
|
||||
model.PopulateTensor<half>(model.input2(),
|
||||
{half(0.1f), half(0.2f), half(0.6f), half(0.5f)});
|
||||
ASSERT_EQ(model.Invoke(), kTfLiteOk);
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, LessEqualInt) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
|
||||
BuiltinOperator_LESS_EQUAL);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user