Add float16 support to comparison ops

PiperOrigin-RevId: 853073492
This commit is contained in:
Dillon Sharlet 2026-01-06 21:50:10 -08:00 committed by Copybara-Service
parent 347c6741ab
commit 63806acb01
2 changed files with 79 additions and 2 deletions

View File

@ -16,11 +16,11 @@ limitations under the License.
#include <stdint.h>
#include "Eigen/Core" // from @eigen_archive
#include "tflite/core/c/common.h"
#include "tflite/kernels/internal/compatibility.h"
#include "tflite/kernels/internal/quantization_util.h"
#include "tflite/kernels/internal/reference/reference_ops.h"
#include "tflite/kernels/internal/tensor.h"
#include "tflite/kernels/internal/reference/string_comparisons.h"
#include "tflite/kernels/internal/tensor_ctypes.h"
#include "tflite/kernels/internal/types.h"
#include "tflite/kernels/kernel_util.h"
@ -180,6 +180,10 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
Comparison<float, reference_ops::EqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::EqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt16:
if (input1->quantization.type == kTfLiteNoQuantization) {
Comparison<int16_t, reference_ops::EqualFn>(input1, input2, output,
@ -239,6 +243,10 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
Comparison<float, reference_ops::NotEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::NotEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::NotEqualFn>(input1, input2, output,
requires_broadcast);
@ -291,6 +299,10 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
Comparison<float, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
@ -332,6 +344,10 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::GreaterEqualFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt16:
Comparison<int16_t, reference_ops::GreaterEqualFn>(input1, input2, output,
requires_broadcast);
@ -430,6 +446,10 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
Comparison<float, reference_ops::LessEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::LessEqualFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output,
requires_broadcast);

View File

@ -125,6 +125,17 @@ TEST(ComparisonsTest, EqualFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, EqualFloat16) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
BuiltinOperator_EQUAL);
model.PopulateTensor<half>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<half>(model.input2(), {0.1, 0.2, 0.6, 0.5});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, EqualInt) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
BuiltinOperator_EQUAL);
@ -220,6 +231,17 @@ TEST(ComparisonsTest, NotEqualFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, NotEqualFloat16) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
BuiltinOperator_NOT_EQUAL);
model.PopulateTensor<half>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<half>(model.input2(), {0.1, 0.2, 0.6, 0.5});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, NotEqualInt) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
BuiltinOperator_NOT_EQUAL);
@ -293,6 +315,17 @@ TEST(ComparisonsTest, GreaterFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterFloat16) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
BuiltinOperator_GREATER);
model.PopulateTensor<half>(model.input1(), {0.1f, 0.9f, 0.7f, 0.3f});
model.PopulateTensor<half>(model.input2(), {0.1f, 0.2f, 0.6f, 0.5f});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterInt) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
BuiltinOperator_GREATER);
@ -338,6 +371,17 @@ TEST(ComparisonsTest, GreaterEqualFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualFloat16) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<half>(model.input1(), {0.1f, 0.9f, 0.7f, 0.3f});
model.PopulateTensor<half>(model.input2(), {0.1f, 0.2f, 0.6f, 0.5f});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualInt) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
BuiltinOperator_GREATER_EQUAL);
@ -478,6 +522,19 @@ TEST(ComparisonsTest, LessEqualFloat) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualFloat16) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT16,
BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<half>(model.input1(),
{half(0.1f), half(0.9f), half(0.7f), half(0.3f)});
model.PopulateTensor<half>(model.input2(),
{half(0.1f), half(0.2f), half(0.6f), half(0.5f)});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualInt) {
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
BuiltinOperator_LESS_EQUAL);