
# Edge OCR Practice: Native Deployment of PP-OCRv5 on Android

Date: December 29, 2025 at 20:17
Reading Time: 23 min read

This page has been translated by AI. In case of any discrepancies, please refer to the original post.

Notes

This blog post:

  • Cover: Generated with Google Nano Banana 2; no copyright reserved.
  • Project Source Code: Open-sourced on GitHub; see the PPOCRv5-Android repository.

Disclaimer:

The author (Fleey) is not a professional in the AI field; this project is driven purely by personal interest. Please forgive any omissions or errors in the text, and feel free to provide corrections!

Introduction

In 2024, Google rebranded TensorFlow Lite as LiteRT. This was not just a branding exercise but marked a paradigm shift in on-device AI from “mobile-first” to “edge-first” 1. In this context, OCR (Optical Character Recognition), as one of the most practical on-device AI applications, is undergoing a silent revolution.

Baidu’s PaddleOCR team released PP-OCRv5 in 2025, a unified OCR model supporting multiple languages including Simplified Chinese, Traditional Chinese, English, and Japanese 2. Its mobile version is only about 70MB, yet it can recognize 18,383 characters within a single model. Behind this number lies the collaborative work of two deep neural networks: detection and recognition.

But the problem is: PP-OCRv5 is trained on the PaddlePaddle framework, while the most mature inference engine on Android devices is LiteRT. How do we bridge this gap?

Let’s start with model conversion and gradually unveil the engineering behind on-device OCR.

flowchart TB
subgraph E2E["End-to-End OCR Pipeline"]
direction TB
subgraph Input["Input"]
IMG[Original Image<br/>Any Size]
end
subgraph Detection["Text Detection - DBNet"]
DET_PRE[Preprocessing<br/>Resize 640x640<br/>ImageNet Normalize]
DET_INF[DBNet Inference<br/>~45ms GPU]
DET_POST[Post-processing<br/>Binarization - Contours - Rotated Rect]
end
subgraph Recognition["Text Recognition - SVTRv2"]
REC_CROP[Perspective Transform Crop<br/>48xW Adaptive Width]
REC_INF[SVTRv2 Inference<br/>~15ms/line GPU]
REC_CTC[CTC Decoding<br/>Merge Duplicates + Remove Blanks]
end
subgraph Output["Output"]
RES[OCR Results<br/>Text + Confidence + Position]
end
end
IMG --> DET_PRE --> DET_INF --> DET_POST
DET_POST -->|N Text Boxes| REC_CROP
REC_CROP --> REC_INF --> REC_CTC --> RES

Model Conversion: The Long Journey from PaddlePaddle to TFLite

Fragmentation in deep learning frameworks is a major pain point in the industry. PyTorch, TensorFlow, PaddlePaddle, ONNX—each framework has its own model format and operator implementations. While ONNX (Open Neural Network Exchange) attempts to be a universal intermediate representation, reality is often harsher than the ideal.

The model conversion path for PP-OCRv5 is as follows:

flowchart LR
subgraph PaddlePaddle["PaddlePaddle Framework"]
PM[inference.json<br/>inference.pdiparams]
end
subgraph ONNX["ONNX Intermediate"]
OM[model.onnx<br/>opset 14]
end
subgraph Optimization["Graph Optimization"]
GS[onnx-graphsurgeon<br/>Operator Decomposition]
end
subgraph TFLite["LiteRT Format"]
TM[model.tflite<br/>FP16 Quantized]
end
PM -->|paddle2onnx| OM
OM -->|HardSigmoid Decomposition<br/>Resize Mode Modification| GS
GS -->|onnx2tf| TM

This path seems simple but hides several nuances.

Pitfall 1: Operator Compatibility in paddle2onnx

paddle2onnx is the official model conversion tool provided by PaddlePaddle. Theoretically, it can convert PaddlePaddle models to ONNX format. However, PP-OCRv5 uses some special operators whose mappings in ONNX are not one-to-one.

paddle2onnx --model_dir PP-OCRv5_mobile_det \
--model_filename inference.json \
--params_filename inference.pdiparams \
--save_file ocr_det_v5.onnx \
--opset_version 14

A key detail here: the PP-OCRv5 model filename is inference.json rather than the traditional inference.pdmodel. This is a change in the model format of newer PaddlePaddle versions that many developers overlook 3.

Pitfall 2: HardSigmoid and GPU Compatibility

The converted ONNX model contains the HardSigmoid operator. Mathematically, this operator is defined as:

$$\text{HardSigmoid}(x) = \max(0, \min(1, \alpha x + \beta))$$

where $\alpha = 0.2$ and $\beta = 0.5$.

The problem is that LiteRT’s GPU Delegate does not support HardSigmoid. When a model contains unsupported operators, the GPU Delegate falls back to the CPU for that entire subgraph, leading to significant performance loss.

The solution is to decompose HardSigmoid into basic operators. Using the onnx-graphsurgeon library, we can perform “surgery” at the computational graph level:

import onnx_graphsurgeon as gs
import numpy as np

def decompose_hardsigmoid(graph: gs.Graph) -> gs.Graph:
    """
    Decompose HardSigmoid into GPU-friendly basic operators
    HardSigmoid(x) = max(0, min(1, alpha*x + beta))
    Decomposed into: Mul -> Add -> Clip
    """
    # Iterate over a copy so removing nodes does not skip elements
    for node in list(graph.nodes):
        if node.op == "HardSigmoid":
            # Get HardSigmoid parameters
            alpha = node.attrs.get("alpha", 0.2)
            beta = node.attrs.get("beta", 0.5)
            input_tensor = node.inputs[0]
            output_tensor = node.outputs[0]
            # Create constant tensors
            alpha_const = gs.Constant(
                name=f"{node.name}_alpha",
                values=np.array([alpha], dtype=np.float32)
            )
            beta_const = gs.Constant(
                name=f"{node.name}_beta",
                values=np.array([beta], dtype=np.float32)
            )
            # Create intermediate variables
            mul_out = gs.Variable(name=f"{node.name}_mul_out")
            add_out = gs.Variable(name=f"{node.name}_add_out")
            # Build decomposed subgraph: x -> Mul(alpha) -> Add(beta) -> Clip(0,1)
            mul_node = gs.Node(
                op="Mul",
                inputs=[input_tensor, alpha_const],
                outputs=[mul_out]
            )
            add_node = gs.Node(
                op="Add",
                inputs=[mul_out, beta_const],
                outputs=[add_out]
            )
            clip_node = gs.Node(
                op="Clip",
                inputs=[add_out],
                outputs=[output_tensor],
                attrs={"min": 0.0, "max": 1.0}
            )
            # Replace original node
            graph.nodes.remove(node)
            graph.nodes.extend([mul_node, add_node, clip_node])
    graph.cleanup().toposort()
    return graph

The key to this decomposition is that Mul, Add, and Clip are all operators fully supported by the LiteRT GPU Delegate. After decomposition, the entire subgraph can be executed continuously on the GPU, avoiding the overhead of CPU-GPU data transfers.

TIP

Why not modify the model training code directly? Because the gradient calculation for HardSigmoid during training differs from Clip. Decomposition should only occur during the inference stage to maintain numerical stability during training.

Pitfall 3: Coordinate Transformation Mode of the Resize Operator

The ONNX Resize operator has a coordinate_transformation_mode attribute, which determines how output coordinates are mapped to input coordinates. PP-OCRv5 uses the half_pixel mode, but LiteRT GPU Delegate has limited support for this mode.

Changing it to asymmetric mode provides better GPU compatibility:

for node in graph.nodes:
    if node.op == "Resize":
        node.attrs["coordinate_transformation_mode"] = "asymmetric"

WARNING

This modification may cause minor numerical differences. In practical testing, the impact of this difference on OCR accuracy is negligible, but it may require careful evaluation in other tasks.
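
For reference, the ONNX Resize specification defines the two modes as different mappings from an output coordinate $x_{out}$ back to an input coordinate $x_{in}$ (with $s$ the ratio of output size to input size); the half-pixel shift is the source of the small numerical differences noted above:

$$x_{in}^{\text{half\_pixel}} = \frac{x_{out} + 0.5}{s} - 0.5, \qquad x_{in}^{\text{asymmetric}} = \frac{x_{out}}{s}$$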

Final Step: onnx2tf and FP16 Quantization

onnx2tf is a tool to convert ONNX models to TFLite format. FP16 (half-precision floating point) quantization is a common choice for mobile deployment. It halves the model size with acceptable accuracy loss and leverages the FP16 compute units of mobile GPUs.

onnx2tf -i ocr_det_v5_fixed.onnx -o converted_det \
-b 1 -ois x:1,3,640,640 -n

The -ois parameter here specifies a static input shape. Static shapes are crucial for GPU acceleration; dynamic shapes would require recompiling the GPU program for every inference, severely impacting performance.

Text Detection: Differentiable Binarization in DBNet

The detection module of PP-OCRv5 is based on DBNet (Differentiable Binarization Network) 4. Traditional text detection methods use a fixed threshold for binarization, whereas DBNet’s innovation lies in letting the network learn the optimal threshold for each pixel.

flowchart TB
subgraph DBNet["DBNet Architecture"]
direction TB
IMG[Input Image<br/>H x W x 3]
BB[Backbone<br/>MobileNetV3]
FPN[FPN Feature Pyramid<br/>Multi-scale Fusion]
subgraph Heads["Dual Branch Output"]
PH[Probability Map Branch<br/>P: H x W x 1]
TH[Threshold Map Branch<br/>T: H x W x 1]
end
DB["Differentiable Binarization<br/>B = sigmoid k * P-T"]
end
IMG --> BB --> FPN
FPN --> PH
FPN --> TH
PH --> DB
TH --> DB

Standard Binarization vs. Differentiable Binarization

Standard binarization is a step function:

$$B_{i,j} = \begin{cases} 1 & \text{if } P_{i,j} \geq t \\ 0 & \text{otherwise} \end{cases}$$

This function is non-differentiable and cannot be trained end-to-end via backpropagation. DBNet proposes an approximate function:

$$\hat{B}_{i,j} = \frac{1}{1 + e^{-k(P_{i,j} - T_{i,j})}}$$

where $P$ is the probability map, $T$ is the threshold map (learned by the network), and $k$ is the amplification factor (set to 50 during training).

TIP

This formula is essentially a Sigmoid function, but with $P - T$ as the input. When $k$ is large enough, its behavior approaches a step function while remaining differentiable.

Engineering Implementation of the Post-processing Pipeline

In the PPOCRv5-Android project, the post-processing pipeline is implemented in postprocess.cpp. The core process includes:

flowchart LR
subgraph Input["Model Output"]
PM[Probability Map P<br/>640 x 640]
end
subgraph Binary["Binarization"]
BT[Threshold Filtering<br/>threshold=0.1]
BM[Binary Map<br/>640 x 640]
end
subgraph Contour["Contour Detection"]
DS[4x Downsampling<br/>160 x 160]
CC[Connected Component Analysis<br/>BFS Traversal]
BD[Boundary Point Extraction]
end
subgraph Geometry["Geometric Calculation"]
CH[Convex Hull Calculation<br/>Graham Scan]
RR[Rotating Calipers<br/>Minimum Area Rectangle]
UC[Unclip Expansion<br/>ratio=1.5]
end
subgraph Output["Output"]
TB[RotatedRect<br/>center, size, angle]
end
PM --> BT --> BM
BM --> DS --> CC --> BD
BD --> CH --> RR --> UC --> TB

In the actual code, the TextDetector::Impl::Detect method demonstrates the complete detection process:

std::vector<RotatedRect> Detect(const uint8_t *image_data,
int width, int height, int stride,
float *detection_time_ms) {
// 1. Calculate scale ratios
scale_x_ = static_cast<float>(width) / kDetInputSize;
scale_y_ = static_cast<float>(height) / kDetInputSize;
// 2. Bilinear interpolation resize to 640x640
image_utils::ResizeBilinear(image_data, width, height, stride,
resized_buffer_.data(), kDetInputSize, kDetInputSize);
// 3. ImageNet Normalization
PrepareFloatInput();
// 4. Inference
auto run_result = compiled_model_->Run(input_buffers_, output_buffers_);
// 5. Binarization
BinarizeOutput(prob_map, total_pixels);
// 6. Contour Detection
auto contours = postprocess::FindContours(binary_map_.data(),
kDetInputSize, kDetInputSize);
// 7. Minimum Area Rectangle + Unclip
for (const auto &contour : contours) {
RotatedRect rect = postprocess::MinAreaRect(contour);
UnclipBox(rect, kUnclipRatio);
// Map coordinates back to original image
rect.center_x *= scale_x_;
rect.center_y *= scale_y_;
// ...
}
}

The key to this process is the “Minimum Area Rotated Rectangle.” Unlike axis-aligned bounding boxes, rotated rectangles can tightly fit text at any angle, which is crucial for tilted text in natural scenes.

Unclip: The Text Box Expansion Algorithm

The text regions output by DBNet are usually slightly smaller than the actual text because the network learns the “core region” of the text. To obtain the complete text boundaries, an expansion (Unclip) operation must be performed on the detected polygons.

The mathematical principle of Unclip is based on the inverse operation of the Vatti polygon clipping algorithm. Given a polygon $P$, the expansion distance $d$ used to produce the expanded polygon $P'$ is:

$$d = \frac{A \times r}{L}$$

where $A$ is the polygon area, $L$ is its perimeter, and $r$ is the expansion ratio (usually set to 1.5).
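
As a quick sanity check with hypothetical numbers: for a detected 100 × 20 box and $r = 1.5$,

$$d = \frac{(100 \times 20) \times 1.5}{2 \times (100 + 20)} = \frac{3000}{240} = 12.5,$$

so the box grows by $2d$ in each dimension, to roughly 125 × 45.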

In postprocess.cpp, the UnclipBox function implements this logic:

void UnclipBox(RotatedRect &box, float unclip_ratio) {
// Calculate expansion distance
float area = box.width * box.height;
float perimeter = 2.0f * (box.width + box.height);
if (perimeter < 1e-6f) return; // Prevent division by zero
// d = A * r / L
float distance = area * unclip_ratio / perimeter;
// Expand outwards: increase width and height by 2d each
box.width += 2.0f * distance;
box.height += 2.0f * distance;
}

This simplified version assumes the text box is a rectangle. For more complex polygons, a full Clipper library implementation for polygon offsetting would be required:

// Full polygon Unclip (using Clipper library)
ClipperLib::Path polygon;
for (const auto& pt : contour) {
polygon.push_back(ClipperLib::IntPoint(
static_cast<int>(pt.x * 1000), // Scale up to maintain precision
static_cast<int>(pt.y * 1000)
));
}
ClipperLib::ClipperOffset offset;
offset.AddPath(polygon, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
ClipperLib::Paths solution;
offset.Execute(solution, distance * 1000); // Expand

NOTE

PPOCRv5-Android chooses simplified rectangular expansion over full polygon offsetting because:

  • Most text boxes are nearly rectangular.
  • The full Clipper library would significantly increase binary size.
  • The simplified version offers better performance.

Text Recognition: SVTRv2 and CTC Decoding

If detection is “finding where the text is,” then recognition is “reading what the text says.” The recognition module of PP-OCRv5 is based on SVTRv2 (Scene Text Recognition with Visual Transformer v2) 5.

Architectural Innovations in SVTRv2

SVTRv2 introduces three key improvements over its predecessor SVTR:

flowchart TB
subgraph SVTRv2["SVTRv2 Architecture"]
direction TB
subgraph Encoder["Visual Encoder"]
PE[Patch Embedding<br/>4x4 Conv]
subgraph Mixing["Mixing Attention Block x12"]
LA[Local Attention<br/>7x7 Window]
GA[Global Attention<br/>Global Receptive Field]
FFN[Feed Forward<br/>MLP]
end
end
subgraph Decoder["CTC Decoder"]
FC[Fully Connected Layer<br/>D -> 18384]
SM[Softmax]
CTC[CTC Decode]
end
end
PE --> LA --> GA --> FFN
FFN --> FC --> SM --> CTC

  1. Mixing Attention Mechanism: Alternates between local attention (capturing stroke details) and global attention (understanding character structure). Local attention uses a 7x7 sliding window, reducing computational complexity from $O(n^2)$ to $O(n \times 49)$.

  2. Multi-scale Feature Fusion: Unlike the single resolution of ViT, SVTRv2 uses different feature map resolutions at different depths, similar to a CNN’s pyramid structure.

  3. Semantic Guidance Module: A lightweight semantic branch is added at the end of the encoder to help the model understand semantic relationships between characters rather than just visual features.

These improvements allow SVTRv2 to achieve accuracy comparable to attention-based methods while maintaining the simplicity of CTC decoding 6.

Why CTC instead of Attention?

There are two mainstream paradigms for text recognition:

  1. CTC (Connectionist Temporal Classification): Treats recognition as a sequence labeling problem where output is aligned with input.
  2. Attention-based Decoder: Uses an attention mechanism to generate output character by character.

Attention methods usually offer higher accuracy, but CTC methods are simpler and faster. SVTRv2’s contribution is that by improving the visual encoder, it allows CTC methods to reach or even exceed the accuracy of attention methods 6.

The core of CTC decoding is “merging duplicates” and “removing blanks”:

flowchart LR
subgraph Input["Model Output"]
L["Logits<br/>[T, 18384]"]
end
subgraph Argmax["Argmax NEON"]
A1["t=0: blank"]
A2["t=1: H"]
A3["t=2: H"]
A4["t=3: blank"]
A5["t=4: e"]
A6["t=5: l"]
A7["t=6: l"]
A8["t=7: l"]
A9["t=8: o"]
end
subgraph Merge["Merge Duplicates"]
M["blank, H, blank, e, l, o"]
end
subgraph Remove["Remove Blanks"]
R["H, e, l, o"]
end
subgraph Output["Output"]
O["Helo - Error"]
end
L --> A1 & A2 & A3 & A4 & A5 & A6 & A7 & A8 & A9
A1 & A2 & A3 & A4 & A5 & A6 & A7 & A8 & A9 --> Merge --> Remove --> Output

Wait, there’s a problem here. If the original text is “Hello,” the two ‘l’s are incorrectly merged. The CTC solution is to insert a blank token between repeated characters.

Correct Encoding: [blank, H, e, l, blank, l, o]
Decoding Result: "Hello"
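
A minimal sketch of the collapse rule makes this concrete (a hypothetical helper, with index 0 assumed to be the blank class):

```cpp
#include <string>
#include <vector>

// Collapse a per-timestep argmax sequence into text.
// indices: class index per time step; labels: index -> character (index 0 = blank).
std::string CollapseCtc(const std::vector<int>& indices,
                        const std::vector<std::string>& labels) {
  constexpr int kBlank = 0;
  std::string out;
  int prev = -1;
  for (int idx : indices) {
    if (idx != kBlank && idx != prev) {  // skip blanks, merge consecutive repeats
      out += labels[idx];
    }
    prev = idx;
  }
  return out;
}

// Example with labels = {"", "H", "e", "l", "o"}:
//   {0, 1, 2, 3, 0, 3, 4} -> "Hello"  (the blank between the two 'l's keeps both)
//   {0, 1, 2, 3, 3, 4}    -> "Helo"   (without the blank, the repeats merge)
```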

NEON-Optimized CTC Decoding

CTC decoding in PPOCRv5-Android uses NEON-optimized Argmax. In text_recognizer.cpp:

inline void ArgmaxNeon8(const float *__restrict__ data, int size,
int &max_idx, float &max_val) {
if (size < 16) {
// Scalar fallback
max_idx = 0;
max_val = data[0];
for (int i = 1; i < size; ++i) {
if (data[i] > max_val) {
max_val = data[i];
max_idx = i;
}
}
return;
}
// NEON vectorization: process 4 floats at a time
float32x4_t v_max = vld1q_f32(data);
int32x4_t v_idx = {0, 1, 2, 3};
int32x4_t v_max_idx = v_idx;
const int32x4_t v_four = vdupq_n_s32(4);
int i = 4;
for (; i + 4 <= size; i += 4) {
float32x4_t v_curr = vld1q_f32(data + i);
v_idx = vaddq_s32(v_idx, v_four);
// Vectorized comparison and conditional selection
uint32x4_t cmp = vcgtq_f32(v_curr, v_max);
v_max = vbslq_f32(cmp, v_curr, v_max); // Select larger value
v_max_idx = vbslq_s32(cmp, v_idx, v_max_idx); // Select corresponding index
}
// Horizontal reduction: find the maximum among the 4 candidates
float max_vals[4];
int32_t max_idxs[4];
vst1q_f32(max_vals, v_max);
vst1q_s32(max_idxs, v_max_idx);
// ... final comparison
}

For an Argmax with 18,384 categories, NEON optimization can provide approximately a 3x speedup.

Mathematical Principles of CTC Loss and Decoding

The core idea of CTC is that given an input sequence $X$ and all possible alignment paths $\pi$, the probability of the target sequence $Y$ is calculated as:

$$P(Y|X) = \sum_{\pi \in \mathcal{B}^{-1}(Y)} P(\pi|X)$$

where $\mathcal{B}$ is a "many-to-one mapping function" that maps a path $\pi$ to the output sequence $Y$ (by merging duplicates and removing blanks).

During inference, we use Greedy Decoding instead of full Beam Search:

std::string CTCGreedyDecode(const float* logits, int time_steps, int num_classes,
const std::vector<std::string>& dictionary) {
std::string result;
int prev_idx = -1; // Used for merging duplicates
for (int t = 0; t < time_steps; ++t) {
// Find the category with the maximum probability for the current time step
int max_idx = 0;
float max_val = logits[t * num_classes];
for (int c = 1; c < num_classes; ++c) {
if (logits[t * num_classes + c] > max_val) {
max_val = logits[t * num_classes + c];
max_idx = c;
}
}
// CTC decoding rules:
// 1. Skip blank token (index 0)
// 2. Merge consecutive duplicate characters
if (max_idx != 0 && max_idx != prev_idx) {
result += dictionary[max_idx - 1]; // -1 because blank occupies index 0
}
prev_idx = max_idx;
}
return result;
}

The time complexity of greedy decoding is $O(T \times C)$, where $T$ is the number of time steps and $C$ is the number of categories. For PP-OCRv5, $T \approx 80$ and $C = 18384$, requiring about 1.5 million comparisons per decoding. This is why NEON optimization is so important.

TIP

Beam Search can improve decoding accuracy, but its computational cost is roughly $k$ times that of greedy decoding (where $k$ is the beam width). On mobile devices, greedy decoding is usually the better choice.

Character Dictionary: The Challenge of 18,383 Characters

PP-OCRv5 supports 18,383 characters, including:

  • Common Simplified Chinese characters
  • Common Traditional Chinese characters
  • English letters and numbers
  • Japanese Hiragana and Katakana
  • Common punctuation and special characters

This dictionary is stored in the keys_v5.txt file, one character per line. During CTC decoding, the model output logits have a shape of [1, T, 18384], where T is the number of time steps, and 18384 = 18383 characters + 1 blank token.
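
For illustration, a minimal loader for this format might look like the following sketch (not the project's code; it assumes the one-character-per-line layout described above and leaves the blank token to the decoder, which maps class c to dictionary[c - 1] as in CTCGreedyDecode):

```cpp
#include <fstream>
#include <string>
#include <vector>

// Load the recognition dictionary: one UTF-8 character per line.
// The blank token is not stored in the file; the decoder treats class 0 as blank.
std::vector<std::string> LoadDictionary(const std::string& keys_path) {
  std::vector<std::string> dictionary;
  dictionary.reserve(18383);
  std::ifstream file(keys_path);
  std::string line;
  while (std::getline(file, line)) {
    // Strip a trailing '\r' in case the file uses CRLF line endings.
    if (!line.empty() && line.back() == '\r') line.pop_back();
    dictionary.push_back(line);
  }
  return dictionary;
}
```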

LiteRT C++ API: Modern Interfaces After the 2024 Refactor

PPOCRv5-Android uses the LiteRT C++ API refactored in 2024, which provides a more modern interface design. Compared to the traditional TFLite C API, the new API offers better type safety and resource management capabilities.

Comparison of Old and New APIs

The LiteRT 2024 refactor brought significant API changes:

| Feature | Old API (TFLite) | New API (LiteRT) |
| --- | --- | --- |
| Namespace | `tflite::` | `litert::` |
| Error Handling | Returns `TfLiteStatus` enum | Returns `Expected<T>` type |
| Memory Management | Manual management | RAII automatic management |
| Delegate Config | Scattered APIs | Unified `Options` class |
| Tensor Access | Pointers + manual casting | Type-safe `TensorBuffer` |

The core advantage of the new API is type safety and automatic resource management. Taking error handling as an example:

// Old API: manual check required for every return value
TfLiteStatus status = TfLiteInterpreterAllocateTensors(interpreter);
if (status != kTfLiteOk) {
// Error handling
}
// New API: uses Expected type, supports method chaining
auto model_result = litert::CompiledModel::Create(env, model_path, options);
if (!model_result) {
LOGE(TAG, "Error: %s", model_result.Error().Message().c_str());
return false;
}
auto model = std::move(*model_result); // Automatic lifecycle management

Environment and Model Initialization

In text_detector.cpp, the initialization process is as follows:

bool Initialize(const std::string &model_path, AcceleratorType accelerator_type) {
// 1. Create LiteRT environment
auto env_result = litert::Environment::Create({});
if (!env_result) {
LOGE(TAG, "Failed to create LiteRT environment: %s",
env_result.Error().Message().c_str());
return false;
}
env_ = std::move(*env_result);
// 2. Configure hardware accelerator
auto options_result = litert::Options::Create();
if (!options_result) {
return false;
}
auto options = std::move(*options_result);
auto hw_accelerator = ToLiteRtAccelerator(accelerator_type);
options.SetHardwareAccelerators(hw_accelerator);
// 3. Compile model
auto model_result = litert::CompiledModel::Create(*env_, model_path, options);
if (!model_result) {
LOGW(TAG, "Failed to create CompiledModel with accelerator %d: %s",
static_cast<int>(accelerator_type),
model_result.Error().Message().c_str());
return false;
}
compiled_model_ = std::move(*model_result);
// 4. Resize input tensor shape
std::vector<int> input_dims = {1, kDetInputSize, kDetInputSize, 3};
compiled_model_->ResizeInputTensor(0, absl::MakeConstSpan(input_dims));
// 5. Create Managed Buffers
CreateBuffersWithCApi();
return true;
}

Managed Tensor Buffer: The Key to Zero-Copy Inference

LiteRT’s Managed Tensor Buffer is key to achieving high-performance inference. It allows the GPU Delegate to access the buffer directly, eliminating CPU-GPU data transfers:

bool CreateBuffersWithCApi() {
LiteRtCompiledModel c_model = compiled_model_->Get();
LiteRtEnvironment c_env = env_->Get();
// Get input buffer requirements
LiteRtTensorBufferRequirements input_requirements = nullptr;
LiteRtGetCompiledModelInputBufferRequirements(
c_model, /*signature_index=*/0, /*input_index=*/0,
&input_requirements);
// Get tensor type information
auto input_type = compiled_model_->GetInputTensorType(0, 0);
LiteRtRankedTensorType tensor_type =
static_cast<LiteRtRankedTensorType>(*input_type);
// Create managed buffer
LiteRtTensorBuffer input_buffer = nullptr;
LiteRtCreateManagedTensorBufferFromRequirements(
c_env, &tensor_type, input_requirements, &input_buffer);
// Wrap as C++ object for automatic lifecycle management
input_buffers_.push_back(
litert::TensorBuffer::WrapCObject(input_buffer,
litert::OwnHandle::kYes));
return true;
}

The advantages of this design are:

  1. Zero-copy inference: The GPU Delegate can access the buffer directly without CPU-GPU data transfer.
  2. Automatic memory management: OwnHandle::kYes ensures the buffer is automatically released when the C++ object is destroyed.
  3. Type safety: Tensor type matching is checked at compile time.

GPU Acceleration: Choosing OpenCL and the Trade-offs

LiteRT provides several hardware acceleration options:

flowchart TB
subgraph Delegates["LiteRT Delegate Ecosystem"]
direction TB
GPU_CL[GPU Delegate<br/>OpenCL Backend]
GPU_GL[GPU Delegate<br/>OpenGL ES Backend]
NNAPI[NNAPI Delegate<br/>Android HAL]
XNN[XNNPACK Delegate<br/>CPU Optimized]
end
subgraph Hardware["Hardware Mapping"]
direction TB
ADRENO[Adreno GPU<br/>Qualcomm]
MALI[Mali GPU<br/>ARM]
NPU[NPU/DSP<br/>Vendor Specific]
CPU[ARM CPU<br/>NEON]
end
GPU_CL --> ADRENO
GPU_CL --> MALI
GPU_GL --> ADRENO
GPU_GL --> MALI
NNAPI --> NPU
XNN --> CPU

| Accelerator | Backend | Pros | Cons |
| --- | --- | --- | --- |
| GPU | OpenCL | Wide support, good performance | Not a standard Android component |
| GPU | OpenGL ES | Standard Android component | Performance inferior to OpenCL |
| NPU | NNAPI | Highest performance | Poor device compatibility |
| CPU | XNNPACK | Widest compatibility | Lowest performance |

PPOCRv5-Android chooses OpenCL as the primary acceleration backend. Google released the OpenCL backend for TFLite in 2020, which achieved about a 2x speedup on Adreno GPUs compared to the OpenGL ES backend 7.

The advantages of OpenCL come from several aspects:

  1. Design intent: OpenCL was designed for general-purpose computing from the start, whereas OpenGL is a graphics rendering API that only later added support for compute shaders.
  2. Constant memory: OpenCL’s constant memory is highly efficient for accessing neural network weights.
  3. FP16 support: OpenCL natively supports half-precision floating point, whereas OpenGL support came later.

However, OpenCL has a fatal flaw: it is not a standard Android component. OpenCL implementations vary in quality across vendors, and some devices do not support it at all.

OpenCL vs. OpenGL ES: Deep Performance Comparison

To understand OpenCL’s advantage, we need to dive into GPU architecture. Taking the Qualcomm Adreno 640 as an example:

flowchart TB
subgraph Adreno["Adreno 640 Architecture"]
direction TB
subgraph SP["Shader Processors x2"]
ALU1[ALU Array<br/>256 FP32 / 512 FP16]
ALU2[ALU Array<br/>256 FP32 / 512 FP16]
end
subgraph Memory["Memory Hierarchy"]
L1[L1 Cache<br/>16KB per SP]
L2[L2 Cache<br/>1MB Shared]
GMEM[Global Memory<br/>LPDDR4X]
end
subgraph Special["Special Units"]
TMU[Texture Unit<br/>Bilinear Interpolation]
CONST[Constant Cache<br/>Weight Acceleration]
end
end
ALU1 --> L1
ALU2 --> L1
L1 --> L2 --> GMEM
TMU --> L1
CONST --> ALU1 & ALU2

OpenCL’s performance advantage stems from:

| Feature | OpenCL | OpenGL ES Compute |
| --- | --- | --- |
| Constant Memory | Native support, hardware accelerated | Emulated via UBO |
| Workgroup Size | Flexibly configured | Limited by shader model |
| Memory Barriers | Fine-grained control | Coarse-grained |
| FP16 Compute | `cl_khr_fp16` extension | Requires `mediump` precision |
| Debugging Tools | Snapdragon Profiler | Limited support |

In convolution operations, weights are typically constant. OpenCL can place weights in constant memory, benefiting from hardware-level broadcast optimizations. OpenGL ES, on the other hand, needs to pass weights as Uniform Buffer Objects (UBOs), increasing memory access overhead.

NOTE

Since Android 7.0, Google has restricted apps from directly loading OpenCL libraries. However, LiteRT’s GPU Delegate bypasses this restriction by dynamically loading the system’s OpenCL implementation via dlopen. This is why the GPU Delegate needs to detect OpenCL availability at runtime.
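
A minimal availability probe along those lines might look like this sketch (the candidate library names are common locations, not a guaranteed or exhaustive list):

```cpp
#include <dlfcn.h>

// Returns true if a system OpenCL implementation can be loaded at runtime.
// Library names and paths vary by vendor; these are common candidates only.
bool IsOpenClAvailable() {
  static const char* kCandidates[] = {
      "libOpenCL.so",
      "libOpenCL-pixel.so",                  // seen on some Pixel devices (assumption)
      "/system/vendor/lib64/libOpenCL.so",
  };
  for (const char* name : kCandidates) {
    if (void* handle = dlopen(name, RTLD_NOW | RTLD_LOCAL)) {
      dlclose(handle);
      return true;
    }
  }
  return false;
}
```

If the probe fails, the engine can skip the GPU entry in the fallback chain described below and go straight to the CPU path.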

Graceful Fallback Strategy

PPOCRv5-Android implements a graceful fallback strategy:

ocr_engine.cpp
constexpr AcceleratorType kFallbackChain[] = {
AcceleratorType::kGpu, // Preferred: GPU
AcceleratorType::kCpu, // Fallback: CPU
};
std::unique_ptr<OcrEngine> OcrEngine::Create(
const std::string &det_model_path,
const std::string &rec_model_path,
const std::string &keys_path,
AcceleratorType accelerator_type) {
auto engine = std::unique_ptr<OcrEngine>(new OcrEngine());
int start_index = GetFallbackStartIndex(accelerator_type);
for (int i = start_index; i < kFallbackChainSize; ++i) {
AcceleratorType current = kFallbackChain[i];
auto detector = TextDetector::Create(det_model_path, current);
if (!detector) continue;
auto recognizer = TextRecognizer::Create(rec_model_path, keys_path, current);
if (!recognizer) continue;
engine->detector_ = std::move(detector);
engine->recognizer_ = std::move(recognizer);
engine->active_accelerator_ = current;
engine->WarmUp();
return engine;
}
return nullptr;
}

This strategy ensures the app can run on any device, albeit with varying performance.

Native Layer: C++ and NEON Optimization

Why use C++ instead of Kotlin?

The answer is simple: performance. Image preprocessing involves massive pixel-level operations, and the overhead of these operations on the JVM is unacceptable. More importantly, C++ can directly use ARM NEON SIMD instructions to achieve vectorized computation.

NEON: ARM’s SIMD Instruction Set

NEON is the SIMD (Single Instruction, Multiple Data) extension for ARM processors. It allows a single instruction to process multiple data elements simultaneously.

flowchart LR
subgraph NEON["128-bit NEON Register"]
direction TB
F4["4x float32"]
I8["8x int16"]
B16["16x int8"]
end
subgraph Operations["Vectorized Operations"]
direction TB
LD["vld1q_f32<br/>Load 4 floats"]
SUB["vsubq_f32<br/>4-way parallel subtraction"]
MUL["vmulq_f32<br/>4-way parallel multiplication"]
ST["vst1q_f32<br/>Store 4 floats"]
end
subgraph Speedup["Performance Boost"]
S1["Scalar: 4 instructions"]
S2["NEON: 1 instruction"]
S3["Theoretical Speedup: 4x"]
end
F4 --> LD
LD --> SUB --> MUL --> ST
ST --> S3

PPOCRv5-Android uses NEON optimization in several critical paths. Taking binarization as an example (text_detector.cpp):

void BinarizeOutput(const float *prob_map, int total_pixels) {
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
const float32x4_t v_threshold = vdupq_n_f32(kBinaryThreshold);
const uint8x16_t v_255 = vdupq_n_u8(255);
const uint8x16_t v_0 = vdupq_n_u8(0);
int i = 0;
for (; i + 16 <= total_pixels; i += 16) {
// Process 16 pixels at a time
float32x4_t f0 = vld1q_f32(prob_map + i);
float32x4_t f1 = vld1q_f32(prob_map + i + 4);
float32x4_t f2 = vld1q_f32(prob_map + i + 8);
float32x4_t f3 = vld1q_f32(prob_map + i + 12);
// Vectorized comparison
uint32x4_t cmp0 = vcgtq_f32(f0, v_threshold);
uint32x4_t cmp1 = vcgtq_f32(f1, v_threshold);
uint32x4_t cmp2 = vcgtq_f32(f2, v_threshold);
uint32x4_t cmp3 = vcgtq_f32(f3, v_threshold);
// Narrow down to uint8
uint16x4_t n0 = vmovn_u32(cmp0);
uint16x4_t n1 = vmovn_u32(cmp1);
uint16x8_t n01 = vcombine_u16(n0, n1);
// ... merge and store
}
// Scalar fallback for remaining pixels
for (; i < total_pixels; ++i) {
binary_map_[i] = (prob_map[i] > kBinaryThreshold) ? 255 : 0;
}
#else
// Pure scalar implementation
for (int i = 0; i < total_pixels; ++i) {
binary_map_[i] = (prob_map[i] > kBinaryThreshold) ? 255 : 0;
}
#endif
}

Key optimization points in this code:

  • Batch loading: vld1q_f32 loads 4 floats at once, reducing memory access frequency.
  • Vectorized comparison: vcgtq_f32 compares 4 values simultaneously to generate a mask.
  • Type narrowing: vmovn_u32 compresses 32-bit results into 16-bit, and eventually to 8-bit.

Compared to a scalar implementation, NEON optimization can provide a 3-4x speedup 8.

NEON Implementation of ImageNet Normalization

Image normalization is a crucial step in preprocessing. ImageNet standardization uses the following formula:

$$x_{\text{normalized}} = \frac{x - \mu}{\sigma}$$

where $\mu = [0.485, 0.456, 0.406]$ and $\sigma = [0.229, 0.224, 0.225]$ (RGB channels).
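
Since per-pixel division is expensive, the constants can be folded so that each channel needs only one multiply-add:

$$\frac{x/255 - \mu}{\sigma} = x \cdot \frac{1}{255\,\sigma} + \left(-\frac{\mu}{\sigma}\right),$$

which is exactly what the precomputed scale and bias vectors in the code below implement.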

In image_utils.cpp, the NEON-optimized normalization is implemented as follows:

void NormalizeImageNet(const uint8_t* src, int width, int height, int stride,
float* dst) {
// ImageNet normalization parameters
constexpr float kMeanR = 0.485f, kMeanG = 0.456f, kMeanB = 0.406f;
constexpr float kStdR = 0.229f, kStdG = 0.224f, kStdB = 0.225f;
constexpr float kInvStdR = 1.0f / kStdR;
constexpr float kInvStdG = 1.0f / kStdG;
constexpr float kInvStdB = 1.0f / kStdB;
constexpr float kScale = 1.0f / 255.0f;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
// Precompute: (1/255) / std = 1 / (255 * std)
const float32x4_t v_scale_r = vdupq_n_f32(kScale * kInvStdR);
const float32x4_t v_scale_g = vdupq_n_f32(kScale * kInvStdG);
const float32x4_t v_scale_b = vdupq_n_f32(kScale * kInvStdB);
// Precompute: -mean / std
const float32x4_t v_bias_r = vdupq_n_f32(-kMeanR * kInvStdR);
const float32x4_t v_bias_g = vdupq_n_f32(-kMeanG * kInvStdG);
const float32x4_t v_bias_b = vdupq_n_f32(-kMeanB * kInvStdB);
for (int y = 0; y < height; ++y) {
const uint8_t* row = src + y * stride;
float* dst_row = dst + y * width * 3;
int x = 0;
for (; x + 4 <= width; x += 4) {
// Load 4 RGBA pixels (16 bytes)
uint8x16_t rgba = vld1q_u8(row + x * 4);
// De-interleave: RGBARGBARGBARGBA -> RRRR, GGGG, BBBB, AAAA
uint8x16x4_t channels = vld4q_u8(row + x * 4);
// uint8 -> uint16 -> uint32 -> float32
uint16x8_t r16 = vmovl_u8(vget_low_u8(channels.val[0]));
uint16x8_t g16 = vmovl_u8(vget_low_u8(channels.val[1]));
uint16x8_t b16 = vmovl_u8(vget_low_u8(channels.val[2]));
float32x4_t r_f = vcvtq_f32_u32(vmovl_u16(vget_low_u16(r16)));
float32x4_t g_f = vcvtq_f32_u32(vmovl_u16(vget_low_u16(g16)));
float32x4_t b_f = vcvtq_f32_u32(vmovl_u16(vget_low_u16(b16)));
// Normalize: (x / 255 - mean) / std = x * (1/255/std) + (-mean/std)
r_f = vmlaq_f32(v_bias_r, r_f, v_scale_r); // fused multiply-add
g_f = vmlaq_f32(v_bias_g, g_f, v_scale_g);
b_f = vmlaq_f32(v_bias_b, b_f, v_scale_b);
// Interleaved store: RRRR, GGGG, BBBB -> RGBRGBRGBRGB
float32x4x3_t rgb = {r_f, g_f, b_f};
vst3q_f32(dst_row + x * 3, rgb);
}
// Scalar processing for remaining pixels
for (; x < width; ++x) {
const uint8_t* px = row + x * 4;
float* dst_px = dst_row + x * 3;
dst_px[0] = (px[0] * kScale - kMeanR) * kInvStdR;
dst_px[1] = (px[1] * kScale - kMeanG) * kInvStdG;
dst_px[2] = (px[2] * kScale - kMeanB) * kInvStdB;
}
}
#else
// Scalar implementation (omitted)
#endif
}

Key optimization techniques in this code:

  1. Precomputing constants: Transforming (x - mean) / std into x * scale + bias to reduce runtime division.
  2. Fused Multiply-Add: vmlaq_f32 performs multiplication and addition in a single instruction.
  3. De-interleaved loading: vld4q_u8 automatically separates RGBA into four channels.
  4. Interleaved storing: vst3q_f32 writes RGB channels back to memory in an interleaved manner.

Zero OpenCV Dependency

Many OCR projects rely on OpenCV for image preprocessing. While OpenCV is powerful, it brings a massive binary footprint; the OpenCV library on Android usually exceeds 10MB.

PPOCRv5-Android chooses a “Zero OpenCV Dependency” route. All image preprocessing operations are implemented in pure C++ in image_utils.cpp:

  • Bilinear interpolation resize: Hand-written implementation with NEON support.
  • Normalization: ImageNet standardization and recognition standardization.
  • Perspective Transform: Cropping text regions at any angle from the original image.

NEON Implementation of Bilinear Interpolation

Bilinear interpolation is the core algorithm for image scaling. Given source image coordinates (x,y)(x, y), bilinear interpolation calculates the target pixel value:

$$f(x, y) = (1-\alpha)(1-\beta)f_{00} + \alpha(1-\beta)f_{10} + (1-\alpha)\beta f_{01} + \alpha\beta f_{11}$$

where $\alpha = x - \lfloor x \rfloor$, $\beta = y - \lfloor y \rfloor$, and $f_{ij}$ are the values of the four neighboring pixels.

void ResizeBilinear(const uint8_t* src, int src_w, int src_h, int src_stride,
uint8_t* dst, int dst_w, int dst_h) {
const float scale_x = static_cast<float>(src_w) / dst_w;
const float scale_y = static_cast<float>(src_h) / dst_h;
for (int dy = 0; dy < dst_h; ++dy) {
const float sy = (dy + 0.5f) * scale_y - 0.5f;
const int y0 = std::max(0, static_cast<int>(std::floor(sy)));
const int y1 = std::min(src_h - 1, y0 + 1);
const float beta = sy - y0;
const float inv_beta = 1.0f - beta;
const uint8_t* row0 = src + y0 * src_stride;
const uint8_t* row1 = src + y1 * src_stride;
uint8_t* dst_row = dst + dy * dst_w * 4;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
// NEON: Process 4 target pixels at a time
const float32x4_t v_beta = vdupq_n_f32(beta);
const float32x4_t v_inv_beta = vdupq_n_f32(inv_beta);
int dx = 0;
for (; dx + 4 <= dst_w; dx += 4) {
// Calculate 4 source coordinates
float sx[4];
for (int i = 0; i < 4; ++i) {
sx[i] = ((dx + i) + 0.5f) * scale_x - 0.5f;
}
// Load alpha weights
float alpha[4], inv_alpha[4];
int x0[4], x1[4];
for (int i = 0; i < 4; ++i) {
x0[i] = std::max(0, static_cast<int>(std::floor(sx[i])));
x1[i] = std::min(src_w - 1, x0[i] + 1);
alpha[i] = sx[i] - x0[i];
inv_alpha[i] = 1.0f - alpha[i];
}
// Perform bilinear interpolation for each channel
for (int c = 0; c < 4; ++c) { // RGBA
float32x4_t f00, f10, f01, f11;
// Gather neighboring values for 4 pixels
f00 = vsetq_lane_f32(row0[x0[0] * 4 + c], f00, 0);
f00 = vsetq_lane_f32(row0[x0[1] * 4 + c], f00, 1);
f00 = vsetq_lane_f32(row0[x0[2] * 4 + c], f00, 2);
f00 = vsetq_lane_f32(row0[x0[3] * 4 + c], f00, 3);
// ... f10, f01, f11 similar
// Bilinear interpolation formula
float32x4_t v_alpha = vld1q_f32(alpha);
float32x4_t v_inv_alpha = vld1q_f32(inv_alpha);
float32x4_t top = vmlaq_f32(
vmulq_f32(f00, v_inv_alpha),
f10, v_alpha
);
float32x4_t bottom = vmlaq_f32(
vmulq_f32(f01, v_inv_alpha),
f11, v_alpha
);
float32x4_t result = vmlaq_f32(
vmulq_f32(top, v_inv_beta),
bottom, v_beta
);
// Convert back to uint8 and store
uint32x4_t result_u32 = vcvtq_u32_f32(result);
// ... store
}
}
#endif
// Scalar processing for remaining pixels (omitted)
}
}

TIP

NEON optimization for bilinear interpolation is complex because the addresses of the four neighboring pixels are non-contiguous. A more efficient method is to use separable bilinear interpolation: interpolate horizontally first, then vertically. This better utilizes cache locality.
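
A minimal sketch of the separable idea for a single-channel image (illustrative scalar code only, no NEON; the project's implementation differs):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Separable bilinear resize for a grayscale image: pass 1 interpolates
// horizontally into a float buffer, pass 2 vertically. Each pass streams
// through memory sequentially instead of gathering four scattered neighbors
// per output pixel, which is friendlier to the cache.
void ResizeBilinearSeparable(const uint8_t* src, int src_w, int src_h,
                             uint8_t* dst, int dst_w, int dst_h) {
  const float scale_x = static_cast<float>(src_w) / dst_w;
  const float scale_y = static_cast<float>(src_h) / dst_h;

  // Pass 1: horizontal interpolation -> src_h rows of dst_w floats.
  std::vector<float> tmp(static_cast<size_t>(src_h) * dst_w);
  for (int y = 0; y < src_h; ++y) {
    const uint8_t* row = src + y * src_w;
    float* tmp_row = tmp.data() + static_cast<size_t>(y) * dst_w;
    for (int dx = 0; dx < dst_w; ++dx) {
      const float sx = (dx + 0.5f) * scale_x - 0.5f;
      const int x0 = std::clamp(static_cast<int>(std::floor(sx)), 0, src_w - 1);
      const int x1 = std::min(x0 + 1, src_w - 1);
      const float a = std::clamp(sx - x0, 0.0f, 1.0f);
      tmp_row[dx] = (1.0f - a) * row[x0] + a * row[x1];
    }
  }

  // Pass 2: vertical interpolation over the intermediate buffer.
  for (int dy = 0; dy < dst_h; ++dy) {
    const float sy = (dy + 0.5f) * scale_y - 0.5f;
    const int y0 = std::clamp(static_cast<int>(std::floor(sy)), 0, src_h - 1);
    const int y1 = std::min(y0 + 1, src_h - 1);
    const float b = std::clamp(sy - y0, 0.0f, 1.0f);
    const float* row0 = tmp.data() + static_cast<size_t>(y0) * dst_w;
    const float* row1 = tmp.data() + static_cast<size_t>(y1) * dst_w;
    uint8_t* dst_row = dst + dy * dst_w;
    for (int dx = 0; dx < dst_w; ++dx) {
      const float v = (1.0f - b) * row0[dx] + b * row1[dx];
      dst_row[dx] = static_cast<uint8_t>(std::clamp(v + 0.5f, 0.0f, 255.0f));
    }
  }
}
```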

The cost of this choice is more development work, but the benefits are significant:

  1. APK size reduced by about 10MB.
  2. Full control over preprocessing logic, facilitating optimization.
  3. Avoidance of OpenCV version compatibility issues.

Perspective Transform: From Rotated Rectangles to Standard Text Lines

Text recognition models expect horizontal text line images as input. However, detected text boxes can be rotated rectangles at any angle. Perspective transform is responsible for “straightening” these rotated rectangular regions.

In text_recognizer.cpp, the CropAndRotate method implements this functionality:

void CropAndRotate(const uint8_t *__restrict__ image_data,
int width, int height, int stride,
const RotatedRect &box, int &target_width) {
// Calculate the four corner points of the rotated rectangle
const float cos_angle = std::cos(box.angle * M_PI / 180.0f);
const float sin_angle = std::sin(box.angle * M_PI / 180.0f);
const float half_w = box.width / 2.0f;
const float half_h = box.height / 2.0f;
float corners[8]; // (x, y) coordinates for 4 corners
corners[0] = box.center_x + (-half_w * cos_angle - (-half_h) * sin_angle);
corners[1] = box.center_y + (-half_w * sin_angle + (-half_h) * cos_angle);
// ... calculate other corners
// Adaptive target width: maintain aspect ratio
const float aspect_ratio = src_width / std::max(src_height, 1.0f);
target_width = static_cast<int>(kRecInputHeight * aspect_ratio);
target_width = std::clamp(target_width, 1, kRecInputWidth); // 48x[1, 320]
// Affine transform matrix
const float a00 = (x1 - x0) * inv_dst_w;
const float a01 = (x3 - x0) * inv_dst_h;
const float a10 = (y1 - y0) * inv_dst_w;
const float a11 = (y3 - y0) * inv_dst_h;
// Bilinear sampling + normalization (NEON optimized)
for (int dy = 0; dy < kRecInputHeight; ++dy) {
for (int dx = 0; dx < target_width; ++dx) {
float sx = base_sx + a00 * dx;
float sy = base_sy + a10 * dx;
BilinearSampleNeon(image_data, stride, sx, sy, dst_row + dx * 3);
}
}
}

Key optimizations in this implementation:

  1. Adaptive width: Dynamically adjusts output width based on the text box aspect ratio, avoiding excessive stretching or compression.
  2. Affine transform approximation: For text boxes that are approximately parallelograms, affine transform is used instead of perspective transform to reduce computation.
  3. NEON Bilinear Sampling: Sampling and normalization are completed in a single pass, reducing memory access.
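
As a sketch of the geometry (assuming the corners are ordered $p_0$ top-left, $p_1$ top-right, $p_2$ bottom-right, $p_3$ bottom-left), the affine approximation maps a destination pixel $(d_x, d_y)$ in the $W \times H$ output to the source location

$$\begin{pmatrix} s_x \\ s_y \end{pmatrix} = p_0 + \frac{d_x}{W}\,(p_1 - p_0) + \frac{d_y}{H}\,(p_3 - p_0),$$

which corresponds to the a00, a01, a10, a11 coefficients computed in CropAndRotate above.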

JNI: The Bridge Between Kotlin and C++

JNI (Java Native Interface) is the bridge for communication between Kotlin/Java and C++. However, JNI calls have overhead, and frequent cross-language calls can severely impact performance.

The design principle of PPOCRv5-Android is to minimize the number of JNI calls. The entire OCR process requires only one JNI call:

sequenceDiagram
participant K as Kotlin Layer
participant J as JNI Bridge
participant N as Native Layer
participant G as GPU
K->>J: process(bitmap)
J->>N: Pass RGBA pointer
Note over N,G: Native layer completes all work
N->>N: Image Preprocessing NEON
N->>G: Text Detection Inference
G-->>N: Probability Map
N->>N: Post-processing Contour Detection
loop Each Text Box
N->>N: Perspective Transform Crop
N->>G: Text Recognition Inference
G-->>N: Logits
N->>N: CTC Decoding
end
N-->>J: OCR Results
J-->>K: List OcrResult

In ppocrv5_jni.cpp, the core nativeProcess function demonstrates this design:

JNIEXPORT jobjectArray JNICALL
Java_me_fleey_ppocrv5_ocr_OcrEngine_nativeProcess(
JNIEnv *env, jobject thiz, jlong handle, jobject bitmap) {
auto *engine = reinterpret_cast<ppocrv5::OcrEngine *>(handle);
// Query Bitmap info and lock pixels
AndroidBitmapInfo bitmap_info;
AndroidBitmap_getInfo(env, bitmap, &bitmap_info);
void *pixels = nullptr;
AndroidBitmap_lockPixels(env, bitmap, &pixels);
// Complete all OCR work in a single JNI call
auto results = engine->Process(
static_cast<const uint8_t *>(pixels),
static_cast<int>(bitmap_info.width),
static_cast<int>(bitmap_info.height),
static_cast<int>(bitmap_info.stride));
AndroidBitmap_unlockPixels(env, bitmap);
// Construct Java object array to return
// ...
}

This design avoids the overhead of passing data back and forth between detection and recognition.

Architecture Design: Modularity and Testability

The architecture of PPOCRv5-Android follows the “Separation of Concerns” principle:

flowchart TB
subgraph UI["Jetpack Compose UI Layer"]
direction LR
CP[CameraPreview]
GP[GalleryPicker]
RO[ResultOverlay]
end
subgraph VM["ViewModel Layer"]
OVM[OCRViewModel<br/>State Management]
end
subgraph Native["Native Layer - C++"]
OE[OcrEngine<br/>Orchestration]
subgraph Detection["Text Detection"]
TD[TextDetector]
DB[DBNet FP16]
end
subgraph Recognition["Text Recognition"]
TR[TextRecognizer]
SVTR[SVTRv2 + CTC]
end
subgraph Preprocessing["Image Processing"]
IP[ImagePreprocessor<br/>NEON Optimized]
PP[PostProcessor<br/>Contour Detection]
end
subgraph Runtime["LiteRT Runtime"]
GPU[GPU Delegate<br/>OpenCL]
CPU[CPU Fallback<br/>XNNPACK]
end
end
CP --> OVM
GP --> OVM
OVM --> RO
OVM <-->|JNI| OE
OE --> TD
OE --> TR
TD --> DB
TR --> SVTR
TD --> IP
TR --> IP
DB --> PP
DB --> GPU
SVTR --> GPU
GPU -.->|Fallback| CPU

The benefits of this layered architecture are:

  1. UI Layer: Pure Kotlin/Compose, focusing on user interaction.
  2. ViewModel Layer: Manages state and business logic.
  3. Native Layer: High-performance computing, completely decoupled from the UI.

Each layer can be tested independently. The Native layer can be unit-tested with Google Test, and the ViewModel layer can be tested with JUnit + MockK.
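
For example, a Google Test case for the geometry code might look like the following sketch (it assumes the Point, RotatedRect, and postprocess::MinAreaRect declarations shown earlier are exposed through a hypothetical postprocess.h header):

```cpp
#include <vector>

#include <gtest/gtest.h>

#include "postprocess.h"  // hypothetical header exposing postprocess::{Point, RotatedRect, MinAreaRect}

TEST(MinAreaRectTest, RecoversAxisAlignedSquare) {
  // A 10x10 axis-aligned square contour should yield a rotated rect
  // centered at (5, 5) with an area of ~100, regardless of the reported angle.
  std::vector<postprocess::Point> contour = {
      {0.0f, 0.0f}, {10.0f, 0.0f}, {10.0f, 10.0f}, {0.0f, 10.0f}};
  postprocess::RotatedRect rect = postprocess::MinAreaRect(contour);
  EXPECT_NEAR(rect.center_x, 5.0f, 1e-2f);
  EXPECT_NEAR(rect.center_y, 5.0f, 1e-2f);
  EXPECT_NEAR(rect.width * rect.height, 100.0f, 1e-1f);
}
```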

Kotlin Layer Encapsulation

In OcrEngine.kt, the Kotlin layer provides a clean API:

class OcrEngine private constructor(
private var nativeHandle: Long,
) : Closeable {
companion object {
init {
System.loadLibrary("ppocrv5_jni")
}
fun create(
context: Context,
acceleratorType: AcceleratorType = AcceleratorType.GPU,
): Result<OcrEngine> = runCatching {
initializeCache(context)
val detModelPath = copyAssetToCache(context, "$MODELS_DIR/$DET_MODEL_FILE")
val recModelPath = copyAssetToCache(context, "$MODELS_DIR/$REC_MODEL_FILE")
val keysPath = copyAssetToCache(context, "$MODELS_DIR/$KEYS_FILE")
val handle = OcrEngine(0).nativeCreate(
detModelPath, recModelPath, keysPath,
acceleratorType.value,
)
if (handle == 0L) {
throw OcrException("Failed to create native OCR engine")
}
OcrEngine(handle)
}
}
fun process(bitmap: Bitmap): List<OcrResult> {
check(nativeHandle != 0L) { "OcrEngine has been closed" }
return nativeProcess(nativeHandle, bitmap)?.toList() ?: emptyList()
}
override fun close() {
if (nativeHandle != 0L) {
nativeDestroy(nativeHandle)
nativeHandle = 0
}
}
}

Advantages of this design:

  1. Uses the Result type to handle initialization errors.
  2. Implements the Closeable interface, supporting use blocks for automatic resource release.
  3. Model files are automatically copied from assets to the cache directory.

Cold Start Optimization

The first inference (cold start) is usually much slower than subsequent inferences (warm start). This is because:

  1. The GPU Delegate needs to compile OpenCL programs.
  2. Model weights need to be transferred from CPU memory to GPU memory.
  3. Various caches need to be warmed up.

PPOCRv5-Android mitigates cold start issues through a Warm-up mechanism:

void OcrEngine::WarmUp() {
LOGD(TAG, "Starting warm-up (%d iterations)...", kWarmupIterations);
// Create a small test image
std::vector<uint8_t> dummy_image(kWarmupImageSize * kWarmupImageSize * 4, 128);
for (int i = 0; i < kWarmupImageSize * kWarmupImageSize; ++i) {
dummy_image[i * 4 + 0] = static_cast<uint8_t>((i * 7) % 256);
dummy_image[i * 4 + 1] = static_cast<uint8_t>((i * 11) % 256);
dummy_image[i * 4 + 2] = static_cast<uint8_t>((i * 13) % 256);
dummy_image[i * 4 + 3] = 255;
}
// Perform a few inferences to warm up
for (int iter = 0; iter < kWarmupIterations; ++iter) {
float detection_time_ms = 0.0f;
detector_->Detect(dummy_image.data(), kWarmupImageSize, kWarmupImageSize,
kWarmupImageSize * 4, &detection_time_ms);
}
LOGD(TAG, "Warm-up completed (accelerator: %s)", AcceleratorName(active_accelerator_));
}

Memory Alignment Optimization

In TextDetector::Impl, all pre-allocated buffers use 64-byte alignment:

// Pre-allocated buffers with cache-line alignment
alignas(64) std::vector<uint8_t> resized_buffer_;
alignas(64) std::vector<float> normalized_buffer_;
alignas(64) std::vector<uint8_t> binary_map_;
alignas(64) std::vector<float> prob_map_;

64-byte alignment corresponds to the cache line size of modern ARM processors. Aligned memory access avoids cache line splits and improves memory access efficiency.
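
One subtlety worth noting: alignas(64) on a std::vector member aligns the vector object itself, not the heap buffer it manages, which the default allocator typically aligns to less than 64 bytes. If the pixel data itself must start on a cache-line boundary, a custom allocator is one option; a minimal sketch, assuming C++17's std::aligned_alloc is available:

```cpp
#include <cstddef>
#include <cstdlib>
#include <new>
#include <vector>

// Minimal cache-line-aligned allocator sketch (not the project's code).
template <typename T, std::size_t Alignment = 64>
struct AlignedAllocator {
  using value_type = T;
  template <typename U>
  struct rebind { using other = AlignedAllocator<U, Alignment>; };

  AlignedAllocator() noexcept = default;
  template <typename U>
  AlignedAllocator(const AlignedAllocator<U, Alignment>&) noexcept {}

  T* allocate(std::size_t n) {
    // std::aligned_alloc requires the size to be a multiple of the alignment.
    const std::size_t bytes =
        ((n * sizeof(T) + Alignment - 1) / Alignment) * Alignment;
    if (void* p = std::aligned_alloc(Alignment, bytes)) return static_cast<T*>(p);
    throw std::bad_alloc();
  }
  void deallocate(T* p, std::size_t) noexcept { std::free(p); }
};

template <typename T, typename U, std::size_t A>
bool operator==(const AlignedAllocator<T, A>&, const AlignedAllocator<U, A>&) { return true; }
template <typename T, typename U, std::size_t A>
bool operator!=(const AlignedAllocator<T, A>&, const AlignedAllocator<U, A>&) { return false; }

// A float buffer whose backing storage starts on a 64-byte boundary:
using AlignedFloatBuffer = std::vector<float, AlignedAllocator<float>>;
```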

Memory Pooling and Object Reuse

Frequent memory allocation and deallocation are performance killers. PPOCRv5-Android uses a pre-allocation strategy, allocating all required memory at once during initialization:

class TextDetector::Impl {
// Pre-allocated buffers, lifecycle tied to Impl
alignas(64) std::vector<uint8_t> resized_buffer_; // 640 * 640 * 4 = 1.6MB
alignas(64) std::vector<float> normalized_buffer_; // 640 * 640 * 3 * 4 = 4.9MB
alignas(64) std::vector<uint8_t> binary_map_; // 640 * 640 = 0.4MB
alignas(64) std::vector<float> prob_map_; // 640 * 640 * 4 = 1.6MB
bool Initialize(...) {
// Allocate once to avoid runtime malloc
resized_buffer_.resize(kDetInputSize * kDetInputSize * 4);
normalized_buffer_.resize(kDetInputSize * kDetInputSize * 3);
binary_map_.resize(kDetInputSize * kDetInputSize);
prob_map_.resize(kDetInputSize * kDetInputSize);
return true;
}
};

Benefits of this design:

  1. Avoids memory fragmentation: All large memory blocks are allocated at startup, preventing fragmentation during runtime.
  2. Reduces system calls: malloc can trigger system calls; pre-allocation avoids this overhead.
  3. Cache-friendly: Consecutively allocated memory is more likely to be physically contiguous, improving cache hit rates.

Branch Prediction Optimization

Modern CPUs use branch prediction to improve pipeline efficiency. Incorrect branch prediction leads to pipeline flushes, costing 10-20 clock cycles.

On hot paths, we use __builtin_expect to hint the compiler:

// Most pixels will not exceed the threshold
if (__builtin_expect(prob_map[i] > kBinaryThreshold, 0)) {
binary_map_[i] = 255;
} else {
binary_map_[i] = 0;
}

__builtin_expect(expr, val) tells the compiler that the value of expr is very likely to be val. The compiler adjusts the code layout accordingly, placing “unlikely” branches away from the main path.

Loop Unrolling and Software Pipelining

For compute-intensive loops, manual unrolling can reduce loop overhead and expose more instruction-level parallelism:

// Non-unrolled version
for (int i = 0; i < n; ++i) {
dst[i] = src[i] * scale + bias;
}
// 4x unrolled version
int i = 0;
for (; i + 4 <= n; i += 4) {
dst[i + 0] = src[i + 0] * scale + bias;
dst[i + 1] = src[i + 1] * scale + bias;
dst[i + 2] = src[i + 2] * scale + bias;
dst[i + 3] = src[i + 3] * scale + bias;
}
for (; i < n; ++i) {
dst[i] = src[i] * scale + bias;
}

After unrolling, the CPU can execute multiple independent multiply-add instructions simultaneously, fully utilizing the multiple execution units of superscalar architectures.

Prefetch Optimization

In the inner loop of the perspective transform, use __builtin_prefetch to load data for the next line in advance:

for (int dy = 0; dy < kRecInputHeight; ++dy) {
// Prefetch next line data
if (dy + 1 < kRecInputHeight) {
const float next_sy = y0 + a11 * (dy + 1);
const int next_y = static_cast<int>(next_sy);
if (next_y >= 0 && next_y < height) {
__builtin_prefetch(image_data + next_y * stride, 0, 1);
}
}
// ... process current line
}

This optimization can hide memory latency; while processing the current line, the data for the next line is already in the L1 cache.

Engineering Details of Post-processing

Connected Component Analysis and Contour Detection

In postprocess.cpp, the FindContours function implements efficient connected component analysis:

std::vector<std::vector<Point>> FindContours(const uint8_t *binary_map,
int width, int height) {
// 1. 4x downsampling to reduce computation
int ds_width = (width + kDownsampleFactor - 1) / kDownsampleFactor;
int ds_height = (height + kDownsampleFactor - 1) / kDownsampleFactor;
std::vector<uint8_t> ds_map(ds_width * ds_height);
downsample_binary_map(binary_map, width, height,
ds_map.data(), ds_width, ds_height, kDownsampleFactor);
// 2. BFS traversal of connected components
std::vector<std::vector<Point>> contours;
std::vector<int> labels(ds_width * ds_height, 0);
int current_label = 0;
for (int y = 0; y < ds_height; ++y) {
for (int x = 0; x < ds_width; ++x) {
if (pixel_at(ds_map.data(), x, y, ds_width) > 0 &&
labels[y * ds_width + x] == 0) {
current_label++;
std::vector<Point> boundary;
std::queue<std::pair<int, int>> queue;
queue.push({x, y});
while (!queue.empty()) {
auto [cx, cy] = queue.front();
queue.pop();
// Detect boundary pixels
if (is_boundary_pixel(ds_map.data(), cx, cy, ds_width, ds_height)) {
boundary.push_back({
static_cast<float>(cx * kDownsampleFactor + kDownsampleFactor / 2),
static_cast<float>(cy * kDownsampleFactor + kDownsampleFactor / 2)
});
}
// 4-neighbor expansion
for (int d = 0; d < 4; ++d) {
int nx = cx + kNeighborDx4[d];
int ny = cy + kNeighborDy4[d];
// ...
}
}
if (boundary.size() >= 4) {
contours.push_back(std::move(boundary));
}
}
}
}
return contours;
}

Key optimization points:

  1. 4x Downsampling: Downsampling the 640x640 binary map to 160x160 reduces computation by 16 times.
  2. Boundary Detection: Only boundary pixels are kept, rather than the entire connected component.
  3. Maximum Contour Limit: kMaxContours = 100 to prevent performance issues in extreme cases.

Convex Hull and Rotating Calipers Algorithms

Calculating the minimum area rotated rectangle involves two steps: first calculating the convex hull, then using the rotating calipers algorithm to find the minimum area bounding rectangle.

Graham Scan Convex Hull Algorithm

Graham Scan is a classic algorithm for calculating the convex hull with a time complexity of $O(n \log n)$:

std::vector<Point> ConvexHull(std::vector<Point> points) {
if (points.size() < 3) return points;
// 1. Find the bottom-most point (min y, then min x)
auto pivot = std::min_element(points.begin(), points.end(),
[](const Point& a, const Point& b) {
return a.y < b.y || (a.y == b.y && a.x < b.x);
});
std::swap(points[0], *pivot);
Point p0 = points[0];
// 2. Sort by polar angle
std::sort(points.begin() + 1, points.end(),
[&p0](const Point& a, const Point& b) {
float cross = CrossProduct(p0, a, b);
if (std::abs(cross) < 1e-6f) {
// When collinear, the closer point comes first
return DistanceSquared(p0, a) < DistanceSquared(p0, b);
}
return cross > 0; // Counter-clockwise direction
});
// 3. Build the convex hull
std::vector<Point> hull;
for (const auto& p : points) {
// Remove points that cause a clockwise turn
while (hull.size() > 1 &&
CrossProduct(hull[hull.size()-2], hull[hull.size()-1], p) <= 0) {
hull.pop_back();
}
hull.push_back(p);
}
return hull;
}
// Cross product: determine turn direction
float CrossProduct(const Point& o, const Point& a, const Point& b) {
return (a.x - o.x) * (b.y - o.y) - (a.y - o.y) * (b.x - o.x);
}

Rotating Calipers Algorithm

The Rotating Calipers algorithm iterates through each edge of the convex hull and calculates the area of the bounding rectangle based on that edge:

RotatedRect MinAreaRect(const std::vector<Point>& hull) {
if (hull.size() < 3) return {};
float min_area = std::numeric_limits<float>::max();
RotatedRect best_rect;
int n = hull.size();
int right = 1, top = 1, left = 1; // Three "caliper" positions
for (int i = 0; i < n; ++i) {
int j = (i + 1) % n;
// Direction vector of the current edge
float edge_x = hull[j].x - hull[i].x;
float edge_y = hull[j].y - hull[i].y;
float edge_len = std::sqrt(edge_x * edge_x + edge_y * edge_y);
// Unit vector
float ux = edge_x / edge_len;
float uy = edge_y / edge_len;
// Perpendicular direction
float vx = -uy;
float vy = ux;
// Find the rightmost point (max projection along edge direction)
while (Dot(hull[(right + 1) % n], ux, uy) > Dot(hull[right], ux, uy)) {
right = (right + 1) % n;
}
// Find the topmost point (max projection along perpendicular direction)
while (Dot(hull[(top + 1) % n], vx, vy) > Dot(hull[top], vx, vy)) {
top = (top + 1) % n;
}
// Find the leftmost point
while (Dot(hull[(left + 1) % n], ux, uy) < Dot(hull[left], ux, uy)) {
left = (left + 1) % n;
}
// Calculate rectangle dimensions
float width = Dot(hull[right], ux, uy) - Dot(hull[left], ux, uy);
float height = Dot(hull[top], vx, vy) - Dot(hull[i], vx, vy);
float area = width * height;
if (area < min_area) {
min_area = area;
// Update optimal rectangle parameters
best_rect.width = width;
best_rect.height = height;
best_rect.angle = std::atan2(uy, ux) * 180.0f / M_PI;
// Calculate center point...
}
}
return best_rect;
}

The key insight of rotating calipers is that as the base edge rotates, the three “calipers” (rightmost, topmost, leftmost points) only move monotonically forward. Thus, the total time complexity is $O(n)$ rather than $O(n^2)$.

Minimum Area Rotated Rectangle

The MinAreaRect function uses the rotating calipers algorithm to calculate the minimum area rotated rectangle:

RotatedRect MinAreaRect(const std::vector<Point> &contour) {
// 1. Subsampling to reduce point count
std::vector<Point> points = subsample_points(contour, kMaxBoundaryPoints);
// 2. Fast path: use AABB for text boxes with high aspect ratios
float aspect = std::max(aabb_width, aabb_height) /
std::max(1.0f, std::min(aabb_width, aabb_height));
if (aspect > 2.0f && points.size() > 50) {
// Return axis-aligned bounding box directly
RotatedRect rect;
rect.center_x = (min_x + max_x) / 2.0f;
rect.center_y = (min_y + max_y) / 2.0f;
rect.width = aabb_width;
rect.height = aabb_height;
rect.angle = 0.0f;
return rect;
}
// 3. Convex hull calculation
std::vector<Point> hull = convex_hull(std::vector<Point>(points));
// 4. Rotating calipers: iterate through each edge of the convex hull
float min_area = std::numeric_limits<float>::max();
RotatedRect best_rect;
for (size_t i = 0; i < hull.size(); ++i) {
size_t j = (i + 1) % hull.size();
// Calculate bounding rectangle based on the current edge
float edge_x = hull[j].x - hull[i].x;
float edge_y = hull[j].y - hull[i].y;
// Project all points onto the edge direction and perpendicular direction
project_points_onto_axis(hull, axis1_x, axis1_y, min1, max1);
project_points_onto_axis(hull, axis2_x, axis2_y, min2, max2);
float area = (max1 - min1) * (max2 - min2);
if (area < min_area) {
min_area = area;
// Update optimal rectangle
}
}
return best_rect;
}

The time complexity of this algorithm is $O(n \log n)$ (convex hull calculation) + $O(n)$ (rotating calipers), where $n$ is the number of boundary points. By subsampling to limit $n$ to within 200, real-time performance is ensured.

Real-time Camera OCR: CameraX and Frame Analysis

The challenge of real-time OCR is how to process each frame as quickly as possible while maintaining a smooth preview.

flowchart TB
subgraph Camera["CameraX Pipeline"]
direction TB
CP[CameraProvider]
PV[Preview UseCase<br/>30 FPS]
IA[ImageAnalysis UseCase<br/>STRATEGY_KEEP_ONLY_LATEST]
end
subgraph Analysis["Frame Analysis Pipeline"]
direction TB
IP[ImageProxy<br/>YUV_420_888]
BM[Bitmap Conversion<br/>RGBA_8888]
JNI[JNI Call<br/>One Cross-Language Call per Frame]
end
subgraph Native["Native OCR"]
direction TB
DET[TextDetector<br/>~45ms GPU]
REC[TextRecognizer<br/>~15ms/line]
RES[OCR Results]
end
subgraph UI["UI Update"]
direction TB
VM[ViewModel<br/>StateFlow]
OV[ResultOverlay<br/>Canvas Drawing]
end
CP --> PV
CP --> IA
IA --> IP --> BM --> JNI
JNI --> DET --> REC --> RES
RES --> VM --> OV

CameraX ImageAnalysis

CameraX is the Android Jetpack camera library, providing the ImageAnalysis use case, which allows us to perform real-time analysis on camera frames:

val imageAnalysis = ImageAnalysis.Builder()
    .setTargetResolution(Size(1280, 720))
    .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
    .build()

imageAnalysis.setAnalyzer(executor) { imageProxy ->
    val bitmap = imageProxy.toBitmap()
    val result = ocrEngine.process(bitmap)
    // Update UI
    imageProxy.close()
}

The key configuration is STRATEGY_KEEP_ONLY_LATEST: when the analyzer’s processing speed cannot keep up with the camera’s frame rate, old frames are discarded, keeping only the latest one. This ensures the timeliness of OCR results.

Trade-off Between Frame Rate and Latency

On GPU-accelerated devices, PPOCRv5-Android can theoretically achieve high processing speeds (although my current Snapdragon 870 device seems to have issues and consistently fails to offload most of the computation to the GPU). However, that doesn't mean we should process every frame.

Consider a scenario where a user points the camera at a block of text; the text content won’t change in a short period. Performing full OCR on every frame would waste significant computational resources.

An optimization strategy is “change detection”: triggering OCR only when the scene changes significantly. This can be achieved by comparing histograms or feature points of consecutive frames.
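
Such a detector is not part of PPOCRv5-Android; the following is a minimal C++ sketch of the histogram variant, assuming the analyzer hands the native side a grayscale view of the frame (for example the Y plane of the YUV_420_888 image). All names here are illustrative:

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: decide whether a new frame differs enough from the
// previous one to justify running the full OCR pipeline again.
class FrameChangeDetector {
public:
    // pixels: grayscale frame data, count = width * height
    bool HasSceneChanged(const uint8_t* pixels, size_t count, float threshold = 0.15f) {
        std::array<float, 64> hist{};          // 64-bin grayscale histogram
        for (size_t i = 0; i < count; ++i) {
            hist[pixels[i] >> 2] += 1.0f;      // 256 gray levels -> 64 bins
        }
        for (float& bin : hist) bin /= static_cast<float>(count);  // normalize

        // L1 distance between consecutive histograms, in [0, 2]
        float diff = 0.0f;
        for (size_t b = 0; b < hist.size(); ++b) {
            diff += std::fabs(hist[b] - prev_hist_[b]);
        }
        prev_hist_ = hist;
        return diff > threshold;  // only re-run OCR when the scene changed enough
    }

private:
    std::array<float, 64> prev_hist_{};  // all zeros, so the first frame always triggers OCR
};

The threshold is a tuning knob: too low and every hand tremor re-triggers OCR, too high and genuine text changes are missed. Feature-point matching gives a more robust signal at a higher per-frame cost.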

Future Outlook: NPU and Quantization

The future of on-device AI lies in NPUs (Neural Processing Units). Compared to GPUs, NPUs are specifically designed for neural network inference and offer higher energy efficiency.

However, the challenge with NPUs is fragmentation. Each chip vendor has its own NPU architecture and SDK:

  • Qualcomm: Hexagon DSP + AI Engine
  • MediaTek: APU
  • Samsung: Exynos NPU
  • Google: Tensor TPU

Android’s NNAPI (Neural Networks API) attempts to provide a unified abstraction layer, but actual results vary. Many NPU features are not exposed through NNAPI, forcing developers to use vendor-specific SDKs.

INT8 Quantization: An Unfinished Battle

FP16 quantization is a conservative choice that loses almost no accuracy. But for extreme performance, INT8 quantization is the next step.

INT8 quantization compresses weights and activations from 32-bit floating point to 8-bit integers, which theoretically provides:

  • 4x model compression.
  • 2-4x inference speedup (depending on hardware).
  • Over 10x speedup on Qualcomm Hexagon DSPs.

This temptation was too great, so I began a long journey into INT8 quantization.

First Attempt: Synthetic Data Calibration

INT8 quantization requires a calibration dataset to determine quantization parameters (Scale and Zero Point). Initially, I took a shortcut and used randomly generated “text-like” images:

# Wrong approach: using random noise for calibration
img = np.ones((h, w, 3), dtype=np.float32) * 0.9
for _ in range(num_lines):
    gray_val = np.random.uniform(0.05, 0.3)
    img[y:y+line_h, x:x+line_w] = gray_val

The result was disastrous. The model output was all zeros:

Raw FLOAT32 output range: min=0.0000, max=0.0000
Prob map stats: min=0.0000, max=0.0000, mean=0.000000

The quantization tool calculated incorrect parameters based on random noise, causing real image activation values to be truncated.

Second Attempt: Real Image Calibration

I switched to real OCR dataset images: ICDAR2015, TextOCR, and PaddleOCR official samples. I also implemented Letterbox preprocessing to ensure the image distribution during calibration matched that during inference:

def letterbox_image(image, target_size):
    """Resize maintaining aspect ratio, pad remaining parts with gray."""
    ih, iw = image.shape[:2]
    h, w = target_size
    scale = min(w / iw, h / ih)
    nw, nh = int(iw * scale), int(ih * scale)
    resized = cv2.resize(image, (nw, nh))
    canvas = np.full((h, w, 3), 128, dtype=image.dtype)  # gray padding (assumes uint8 pixels)
    top, left = (h - nh) // 2, (w - nw) // 2
    canvas[top:top + nh, left:left + nw] = resized       # center paste
    return canvas

The model no longer output all zeros, but the recognition results were still gibberish.

Third Attempt: Fixing Type Handling on the C++ Side

I discovered that the C++ code had issues handling INT8 inputs: the INT8 model expects raw pixel values (0-255), but I was still performing ImageNet normalization (subtracting the mean and dividing by the standard deviation).

if (input_is_int8_) {
    // INT8 model: input raw pixels directly, normalization fused into the first layer.
    // XOR with 0x80 maps uint8 [0, 255] to int8 [-128, 127] (equivalent to subtracting 128).
    dst[i * 3 + 0] = static_cast<int8_t>(src[i * 4 + 0] ^ 0x80);
} else {
    // FP32 model: manual normalization required
    // (pixel - mean) / std
}

I also implemented logic to dynamically read quantization parameters instead of hardcoding them:

bool GetQuantizationParams(LiteRtTensor tensor, float* scale, int32_t* zero_point) {
    LiteRtQuantization quant;
    LiteRtGetTensorQuantization(tensor, &quant);
    // ...
}
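
These parameters feed the standard affine mapping real_value = scale × (quantized − zero_point), the inverse of the quantization formula discussed below. As a hedged illustration (the function below is not from the project; names are mine), the detector's INT8 output could be mapped back to a float probability map like this:

#include <cstddef>
#include <cstdint>
#include <vector>

// Affine dequantization: real_value = scale * (quantized - zero_point).
// scale and zero_point are read from the tensor at runtime rather than hardcoded.
std::vector<float> DequantizeProbMap(const int8_t* data, size_t count,
                                     float scale, int32_t zero_point) {
    std::vector<float> prob(count);
    for (size_t i = 0; i < count; ++i) {
        prob[i] = scale * (static_cast<int32_t>(data[i]) - zero_point);
    }
    return prob;  // values back in DBNet's [0, 1] probability range
}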

Final Result: Compromise

After days of debugging, the INT8 model still failed to work correctly. The issues likely stemmed from:

  1. onnx2tf’s quantization implementation: PP-OCRv5 uses some special operator combinations that onnx2tf might not have handled correctly during quantization.
  2. DBNet’s output characteristics: DBNet outputs a probability map with values between 0 and 1; INT8 quantization is particularly sensitive to such small ranges.
  3. Error accumulation in multi-stage models: Detection and recognition models are cascaded, so quantization errors accumulate and amplify.

Let’s analyze the second point further. DBNet’s output passes through a Sigmoid activation, compressing the range to [0, 1]. INT8 quantization uses the following formula:

$$x_{quantized} = \text{round}\left(\frac{x_{float}}{scale}\right) + zero\_point$$

For values in the [0, 1] range, if the scale is set incorrectly, the quantized values might only occupy a small fraction of the INT8 range [-128, 127], leading to severe precision loss.

# Assume scale = 0.00784 (1/127), zero_point = 0
# Input 0.5 -> round(0.5 / 0.00784) + 0 = 64
# Input 0.1 -> round(0.1 / 0.00784) + 0 = 13
# Input 0.01 -> round(0.01 / 0.00784) + 0 = 1
# Input 0.001 -> round(0.001 / 0.00784) + 0 = 0 # Precision lost!

The threshold for DBNet is usually set to 0.1-0.3, meaning a large number of meaningful probability values (0.1-0.3) can only be represented by 25 integers (13-38) after quantization, resulting in insufficient resolution.

WARNING

INT8 quantization for PP-OCRv5 is a known difficult problem. If you are attempting this, it’s recommended to first ensure the FP32 model works correctly before troubleshooting quantization issues. Alternatively, consider using the official Paddle Lite framework from PaddlePaddle, which has better support for PaddleOCR.

Quantization-Aware Training: The Correct Solution

If INT8 quantization is mandatory, the correct approach is Quantization-Aware Training (QAT) rather than Post-Training Quantization (PTQ).

QAT simulates quantization errors during the training process, allowing the model to learn to adapt to low-precision representations:

# PyTorch QAT Example
import torch.quantization as quant

model = DBNet()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
model_prepared = quant.prepare_qat(model.train())  # insert fake-quant nodes

# Normal training, but with fake quantization nodes inserted in forward passes
for epoch in range(num_epochs):
    for images, labels in dataloader:
        optimizer.zero_grad()
        outputs = model_prepared(images)  # Includes quantization simulation
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Convert to a real quantized model (requires eval mode)
model_quantized = quant.convert(model_prepared.eval())

Unfortunately, the official PP-OCRv5 does not provide QAT-trained models. This means that to obtain a high-quality INT8 model, one would need to perform QAT training from scratch, which is beyond the scope of this project.

Ultimately, I chose to compromise: using FP16 quantization + GPU acceleration instead of INT8 + DSP.

The costs of this decision are:

  • Model size is twice that of INT8.
  • Cannot leverage the ultra-low power consumption of the Hexagon DSP.
  • Inference speed is 2-3x slower than the theoretical optimum.

But the benefits are:

  • Model accuracy is almost identical to FP32.
  • Development cycle is significantly shortened.
  • Code complexity is reduced.

The essence of engineering is trade-offs. Sometimes, “good enough” is more important than “theoretically optimal.”

Conclusion

From PaddlePaddle to TFLite, from DBNet to SVTRv2, and from OpenCL to NEON, the engineering practice of on-device OCR involves knowledge across multiple fields: deep learning, compilers, GPU programming, and mobile development.

The core lesson of this project is that on-device AI is not just about “putting a model on a phone.” It requires:

  1. Deeply understanding the model architecture to convert it correctly.
  2. Familiarity with hardware characteristics to fully utilize accelerators.
  3. Mastery of system programming to implement high-performance native code.
  4. Focus on user experience to find the balance between performance and power consumption.

PPOCRv5-Android is an open-source project that demonstrates how to deploy modern OCR models into actual mobile applications. I hope this article provides some reference for developers with similar needs.

As Google stated at the launch of LiteRT: “Maximum performance, simplified.” 9 The goal of on-device AI is not complexity, but making complexity simple.

Afterword

To be honest, I have been away from the Android field (both professionally and as a hobby) for at least two years. This is the first time I’ve publicly released a relatively mature library on my GitHub secondary account (I’ve handed over my primary account to colleagues to show my determination to move on).

Over the years, my work focus hasn’t actually been in the Android field. I can’t disclose the specifics, but I’ll have the chance to elaborate in the future. In short, it might be difficult for me to make further contributions to Android.

The release of this project was driven by my personal interest—I’m building an early-stage tool based on Android on-device capabilities, and OCR is just a small part of its underlying layer. The full source code will be opened soon (likely very soon), though I can’t reveal more for now.

Anyway, thank you for reading this far, and I look forward to you giving my repository a Star. Thank you!


References

Footnotes

  1. Google AI Edge. “LiteRT: Maximum performance, simplified.” 2024. https://developers.googleblog.com/litert-maximum-performance-simplified/

  2. PaddleOCR Team. “PaddleOCR 3.0 Technical Report.” arXiv:2507.05595, 2025. https://arxiv.org/abs/2507.05595

  3. GitHub Discussion. “Problem while deploying the newest official PP-OCRv5.” PaddleOCR #16100, 2025. https://github.com/PaddlePaddle/PaddleOCR/discussions/16100

  4. Liao, M., et al. “Real-time Scene Text Detection with Differentiable Binarization.” Proceedings of the AAAI Conference on Artificial Intelligence, 2020. https://arxiv.org/abs/1911.08947

  5. Du, Y., et al. “SVTR: Scene Text Recognition with a Single Visual Model.” IJCAI, 2022. https://arxiv.org/abs/2205.00159

  6. Du, Y., et al. “SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition.” ICCV, 2025. https://arxiv.org/abs/2411.15858

  7. TensorFlow Blog. “Even Faster Mobile GPU Inference with OpenCL.” 2020. https://blog.tensorflow.org/2020/08/faster-mobile-gpu-inference-with-opencl.html

  8. ARM Developer. “Neon Intrinsics on Android.” ARM Documentation, 2024. https://developer.arm.com/documentation/101964/latest/

  9. Google AI Edge. “LiteRT Documentation.” 2024. https://ai.google.dev/edge/litert

~
~
mobile/ppocrv5-android.md
$ license --info

License

Unless otherwise stated, all articles and materials on this blog are licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0)
