Mirror of https://github.com/google-ai-edge/LiteRT.git (synced 2026-01-09 06:21:22 +08:00)
Minor code size reduction for select and condition ops

This change makes the code size smaller than it was before adding float16 support for these ops.

PiperOrigin-RevId: 853370966

This commit is contained in:
parent 98abb02117 · commit 5f7f029d39
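The first file's hunks below implement a delegation pattern: keep a single comparison evaluator parameterized by which node input is treated as the left-hand side, and turn Greater/GreaterEqual into thin wrappers that call the Less/LessEqual evaluator with the input indices swapped, so only one per-type Comparison<> switch is instantiated per op pair. The following sketch is illustrative only and is not LiteRT code: Node, its inputs vector, and the return type are hypothetical stand-ins, while kInputTensor1/kInputTensor2 mirror the constants that appear in the diff.

#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;

// Hypothetical stand-in for a TfLiteNode: just the two input buffers.
struct Node {
  std::vector<std::vector<float>> inputs;
};

// The single parameterized evaluator; in the real kernel this is where the
// switch over tensor types lives.
std::vector<bool> LessEval(const Node& node, int lhs, int rhs) {
  const std::vector<float>& a = node.inputs[lhs];
  const std::vector<float>& b = node.inputs[rhs];
  std::vector<bool> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] < b[i];
  return out;
}

// Thin wrappers: the only per-op code is the choice of operand order.
std::vector<bool> LessEval(const Node& node) {
  return LessEval(node, kInputTensor1, kInputTensor2);
}
std::vector<bool> GreaterEval(const Node& node) {
  return LessEval(node, kInputTensor2, kInputTensor1);  // a > b  <=>  b < a
}

int main() {
  Node node{{{1.f, 3.f, 2.f}, {2.f, 3.f, 1.f}}};
  std::vector<bool> greater = GreaterEval(node);
  assert(!greater[0] && !greater[1] && greater[2]);
  return 0;
}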
@@ -283,107 +283,12 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }

-TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node, int lhs,
+                      int rhs) {
   const TfLiteTensor* input1;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor1, &input1));
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, lhs, &input1));
   const TfLiteTensor* input2;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor2, &input2));
-  TfLiteTensor* output;
-  TF_LITE_ENSURE_OK(context,
-                    GetOutputSafe(context, node, kOutputTensor, &output));
-  bool requires_broadcast = !HaveSameShapes(input1, input2);
-  switch (input1->type) {
-    case kTfLiteFloat32:
-      Comparison<float, reference_ops::GreaterFn>(input1, input2, output,
-                                                  requires_broadcast);
-      break;
-    case kTfLiteFloat16:
-      Comparison<Eigen::half, reference_ops::GreaterFn>(input1, input2, output,
-                                                        requires_broadcast);
-      break;
-    case kTfLiteInt32:
-      Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output,
-                                                    requires_broadcast);
-      break;
-    case kTfLiteInt64:
-      Comparison<int64_t, reference_ops::GreaterFn>(input1, input2, output,
-                                                    requires_broadcast);
-      break;
-    case kTfLiteUInt8:
-      ComparisonQuantized<uint8_t, reference_ops::GreaterFn>(
-          input1, input2, output, requires_broadcast);
-      break;
-    case kTfLiteInt8:
-      ComparisonQuantized<int8_t, reference_ops::GreaterFn>(
-          input1, input2, output, requires_broadcast);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context,
-                         "Does not support type %d, requires float|int|uint8",
-                         input1->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-
-TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input1;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor1, &input1));
-  const TfLiteTensor* input2;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor2, &input2));
-  TfLiteTensor* output;
-  TF_LITE_ENSURE_OK(context,
-                    GetOutputSafe(context, node, kOutputTensor, &output));
-  bool requires_broadcast = !HaveSameShapes(input1, input2);
-  switch (input1->type) {
-    case kTfLiteFloat32:
-      Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output,
-                                                       requires_broadcast);
-      break;
-    case kTfLiteFloat16:
-      Comparison<Eigen::half, reference_ops::GreaterEqualFn>(
-          input1, input2, output, requires_broadcast);
-      break;
-    case kTfLiteInt16:
-      Comparison<int16_t, reference_ops::GreaterEqualFn>(input1, input2, output,
-                                                         requires_broadcast);
-      break;
-    case kTfLiteInt32:
-      Comparison<int32_t, reference_ops::GreaterEqualFn>(input1, input2, output,
-                                                         requires_broadcast);
-      break;
-    case kTfLiteInt64:
-      Comparison<int64_t, reference_ops::GreaterEqualFn>(input1, input2, output,
-                                                         requires_broadcast);
-      break;
-    case kTfLiteUInt8:
-      ComparisonQuantized<uint8_t, reference_ops::GreaterEqualFn>(
-          input1, input2, output, requires_broadcast);
-      break;
-    case kTfLiteInt8:
-      ComparisonQuantized<int8_t, reference_ops::GreaterEqualFn>(
-          input1, input2, output, requires_broadcast);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context,
-                         "Does not support type %d, requires float|int|uint8",
-                         input1->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-
-TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input1;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor1, &input1));
-  const TfLiteTensor* input2;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, rhs, &input2));
   TfLiteTensor* output;
   TF_LITE_ENSURE_OK(context,
                     GetOutputSafe(context, node, kOutputTensor, &output));
@@ -430,13 +335,12 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }

-TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node, int lhs,
+                           int rhs) {
   const TfLiteTensor* input1;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor1, &input1));
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, lhs, &input1));
   const TfLiteTensor* input2;
-  TF_LITE_ENSURE_OK(context,
-                    GetInputSafe(context, node, kInputTensor2, &input2));
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, rhs, &input2));
   TfLiteTensor* output;
   TF_LITE_ENSURE_OK(context,
                     GetOutputSafe(context, node, kOutputTensor, &output));
@@ -450,6 +354,10 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
       Comparison<Eigen::half, reference_ops::LessEqualFn>(
           input1, input2, output, requires_broadcast);
       break;
+    case kTfLiteInt16:
+      Comparison<int16_t, reference_ops::LessEqualFn>(input1, input2, output,
+                                                      requires_broadcast);
+      break;
     case kTfLiteInt32:
       Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output,
                                                       requires_broadcast);
@@ -475,6 +383,20 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }

+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+  return LessEval(context, node, kInputTensor1, kInputTensor2);
+}
+TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
+  return LessEqualEval(context, node, kInputTensor1, kInputTensor2);
+}
+
+TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
+  return LessEval(context, node, kInputTensor2, kInputTensor1);
+}
+TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
+  return LessEqualEval(context, node, kInputTensor2, kInputTensor1);
+}
+
 }  // namespace
 }  // namespace comparisons
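The operand swap is exact for every type these kernels handle: a > b is the same predicate as b < a, and a >= b the same as b <= a, including IEEE floats, where both sides of either identity are false whenever a NaN is involved, so routing Greater/GreaterEqual through the Less/LessEqual evaluators cannot change results. That is also why the kTfLiteInt16 case is added to LessEqualEval above: GreaterEqualEval, which previously handled Int16 directly, now delegates to it. A minimal standalone check of the identity (plain C++, not tied to the kernels):

#include <cassert>
#include <limits>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float vals[] = {-1.0f, 0.0f, 2.5f, nan};
  for (float a : vals) {
    for (float b : vals) {
      // a > b must agree with b < a, and a >= b with b <= a, including when
      // either operand is NaN (both comparisons are then false).
      assert((a > b) == (b < a));
      assert((a >= b) == (b <= a));
    }
  }
  return 0;
}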
@@ -148,35 +148,31 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
       GetTensorShape(input_y), GetTensorData<type>(input_y), \
       GetTensorShape(output), GetTensorData<type>(output));

+// Select is basically just a conditional copy, so we don't care what the type
+// of the values as long as the type is the correct size.
 #define TF_LITE_SWITCH(type, op) \
   switch (type) { \
     break; \
     case kTfLiteBool: \
       TF_LITE_SELECT(bool, op); \
       break; \
-    case kTfLiteFloat32: \
-      TF_LITE_SELECT(float, op); \
-      break; \
-    case kTfLiteFloat16: \
-      TF_LITE_SELECT(half, op); \
-      break; \
-    case kTfLiteUInt8: \
-      TF_LITE_SELECT(uint8_t, op); \
-      break; \
-    case kTfLiteInt8: \
-      TF_LITE_SELECT(int8_t, op); \
-      break; \
     case kTfLiteUInt32: \
+    case kTfLiteInt32: \
+    case kTfLiteFloat32: \
       TF_LITE_SELECT(uint32_t, op); \
       break; \
     case kTfLiteUInt16: \
     case kTfLiteInt16: \
-      TF_LITE_SELECT(int16_t, op); \
+    case kTfLiteFloat16: \
+    case kTfLiteBFloat16: \
+      TF_LITE_SELECT(uint16_t, op); \
       break; \
-    case kTfLiteInt32: \
-      TF_LITE_SELECT(int32_t, op); \
+    case kTfLiteUInt8: \
+    case kTfLiteInt8: \
+      TF_LITE_SELECT(uint8_t, op); \
       break; \
+    case kTfLiteUInt64: \
     case kTfLiteInt64: \
-      TF_LITE_SELECT(int64_t, op); \
+      TF_LITE_SELECT(uint64_t, op); \
       break; \
     default: \
       TF_LITE_KERNEL_LOG(context, \
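The Select hunk applies the idea stated in the new comment: Select only copies elements from one input or the other, so types of equal width can share a single TF_LITE_SELECT instantiation (float32 handled as uint32_t, float16/bfloat16/int16/uint16 as uint16_t, and so on), which removes several template instantiations from the switch. Below is a standalone sketch of why a conditional copy over same-width unsigned integers reproduces a float-typed select bit for bit; it is illustrative only, and SelectBits and the buffers are hypothetical names, not the kernel's API.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Conditional copy performed on the raw 32-bit pattern of each element.
std::vector<uint32_t> SelectBits(const std::vector<bool>& cond,
                                 const std::vector<uint32_t>& x,
                                 const std::vector<uint32_t>& y) {
  std::vector<uint32_t> out(cond.size());
  for (std::size_t i = 0; i < cond.size(); ++i) out[i] = cond[i] ? x[i] : y[i];
  return out;
}

int main() {
  std::vector<bool> cond = {true, false, true};
  std::vector<float> x = {1.5f, -2.0f, 3.25f};
  std::vector<float> y = {9.0f, 8.0f, 7.0f};

  // Reinterpret the float buffers as same-width unsigned integers (memcpy
  // avoids strict-aliasing problems), run the copy, and convert back.
  std::vector<uint32_t> x_bits(x.size()), y_bits(y.size());
  std::memcpy(x_bits.data(), x.data(), x.size() * sizeof(float));
  std::memcpy(y_bits.data(), y.data(), y.size() * sizeof(float));
  std::vector<uint32_t> out_bits = SelectBits(cond, x_bits, y_bits);
  std::vector<float> out(out_bits.size());
  std::memcpy(out.data(), out_bits.data(), out.size() * sizeof(float));

  // The bit-level copy produces exactly what a float-typed select would.
  assert(out[0] == 1.5f && out[1] == 8.0f && out[2] == 7.0f);
  return 0;
}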