Minor code size reduction for select and condition ops

This change makes the code size smaller than it was before adding float16 support for these ops.

PiperOrigin-RevId: 853370966
This commit is contained in:
Dillon Sharlet 2026-01-07 12:34:31 -08:00 committed by Copybara-Service
parent 98abb02117
commit 5f7f029d39
2 changed files with 39 additions and 121 deletions

View File

@ -283,107 +283,12 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node, int lhs,
int rhs) {
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, lhs, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
Comparison<float, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt64:
Comparison<int64_t, reference_ops::GreaterFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteUInt8:
ComparisonQuantized<uint8_t, reference_ops::GreaterFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt8:
ComparisonQuantized<int8_t, reference_ops::GreaterFn>(
input1, input2, output, requires_broadcast);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Does not support type %d, requires float|int|uint8",
input1->type);
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
bool requires_broadcast = !HaveSameShapes(input1, input2);
switch (input1->type) {
case kTfLiteFloat32:
Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteFloat16:
Comparison<Eigen::half, reference_ops::GreaterEqualFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt16:
Comparison<int16_t, reference_ops::GreaterEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::GreaterEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt64:
Comparison<int64_t, reference_ops::GreaterEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteUInt8:
ComparisonQuantized<uint8_t, reference_ops::GreaterEqualFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt8:
ComparisonQuantized<int8_t, reference_ops::GreaterEqualFn>(
input1, input2, output, requires_broadcast);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Does not support type %d, requires float|int|uint8",
input1->type);
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, rhs, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
@ -430,13 +335,12 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node, int lhs,
int rhs) {
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, lhs, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, rhs, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
@ -450,6 +354,10 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
Comparison<Eigen::half, reference_ops::LessEqualFn>(
input1, input2, output, requires_broadcast);
break;
case kTfLiteInt16:
Comparison<int16_t, reference_ops::LessEqualFn>(input1, input2, output,
requires_broadcast);
break;
case kTfLiteInt32:
Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output,
requires_broadcast);
@ -475,6 +383,20 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
return LessEval(context, node, kInputTensor1, kInputTensor2);
}
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
return LessEqualEval(context, node, kInputTensor1, kInputTensor2);
}
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
return LessEval(context, node, kInputTensor2, kInputTensor1);
}
TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
return LessEqualEval(context, node, kInputTensor2, kInputTensor1);
}
} // namespace
} // namespace comparisons

View File

@ -148,35 +148,31 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(input_y), GetTensorData<type>(input_y), \
GetTensorShape(output), GetTensorData<type>(output));
// Select is basically just a conditional copy, so we don't care what the type
// of the values as long as the type is the correct size.
#define TF_LITE_SWITCH(type, op) \
switch (type) { \
break; \
case kTfLiteBool: \
TF_LITE_SELECT(bool, op); \
break; \
case kTfLiteFloat32: \
TF_LITE_SELECT(float, op); \
break; \
case kTfLiteFloat16: \
TF_LITE_SELECT(half, op); \
break; \
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
case kTfLiteInt8: \
TF_LITE_SELECT(int8_t, op); \
break; \
case kTfLiteUInt32: \
case kTfLiteInt32: \
case kTfLiteFloat32: \
TF_LITE_SELECT(uint32_t, op); \
break; \
case kTfLiteUInt16: \
case kTfLiteInt16: \
TF_LITE_SELECT(int16_t, op); \
case kTfLiteFloat16: \
case kTfLiteBFloat16: \
TF_LITE_SELECT(uint16_t, op); \
break; \
case kTfLiteInt32: \
TF_LITE_SELECT(int32_t, op); \
case kTfLiteUInt8: \
case kTfLiteInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
case kTfLiteUInt64: \
case kTfLiteInt64: \
TF_LITE_SELECT(int64_t, op); \
TF_LITE_SELECT(uint64_t, op); \
break; \
default: \
TF_LITE_KERNEL_LOG(context, \