Enhancing Progressive Embedding Alignment for Test-Time Adaptation

Introduction

When a deep learning model is deployed in the wild, its accuracy can silently decay because of data drift: the distribution of the test data gradually shifts away from that of the training data. Test-time adaptation (TTA) addresses distribution shift after deployment, when only unlabeled target batches are available. Backpropagation-based TTA methods adapt model parameters or normalization-layer parameters and statistics during inference. This requires computing gradients and storing intermediate activations. That cost is prohibitive in real-time and edge settings, where latency, FLOPs, and memory are fixed system constraints.

Progressive Embedding Alignment [1] (PEA) avoids backpropagation by aligning intermediate embeddings toward source-domain feature statistics. The original method frames domain shift as structured deformation in embedding space and applies source-anchored covariance alignment during inference. Its remaining cost comes from its online procedure: standard PEA uses two forward passes per batch, with the first pass estimating layer-wise alignment weights and the second pass applying alignment.

Although the method avoids backpropagation, its two-pass procedure roughly doubles the forward cost of inference, which undercuts its appeal for latency-bound deployment. Can the approach be made cheap enough for real-world use?

The proposed enhancement converts this procedure into a single-forward-pass variant. It preserves PEA’s source-anchored embedding alignment, but replaces the preliminary weight-estimation pass and hard reset logic with statistics-driven alignment gates and continuous adaptive EMA updates.

Progressive Embedding Alignment

PEA models deployment shift as low-order geometric mismatch in intermediate representations: translation (mean shift), scaling (variance shift), and rotation or shearing (covariance shift). This interpretation motivates correcting features inside the network while leaving the model's parameters unchanged.

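For each layer or block $\ell$, PEA stores source statistics offline:

$$\mu_s^{\ell} = \mathbb{E}_{x \sim \mathcal{D}_s}\big[h^{\ell}(x)\big], \qquad \Sigma_s^{\ell} = \mathrm{Cov}_{x \sim \mathcal{D}_s}\big[h^{\ell}(x)\big].$$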
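Here is a concrete sketch of this offline step. The snippet computes the mean and covariance of one block's features with NumPy. It assumes the block's activations have already been collected on source data and flattened to a matrix with one row per sample. For a ViT, one row per token is a plausible choice, but that is our assumption, not a detail taken from [1]. The random features stand in for real activations.

```python
import numpy as np


def source_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return the mean (C,) and unbiased covariance (C, C) of features shaped (N, C)."""
    mu = features.mean(axis=0)
    centered = features - mu
    sigma = centered.T @ centered / (features.shape[0] - 1)
    return mu, sigma


# Stand-in for block-l activations gathered offline on source data (N samples, C channels).
rng = np.random.default_rng(0)
source_feats = rng.normal(loc=2.0, scale=0.5, size=(4096, 8))
mu_s, sigma_s = source_statistics(source_feats)
print(mu_s.shape, sigma_s.shape)  # (8,) (8, 8)
```

These two arrays per block are all the source-side state the alignment needs at test time.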

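At test time, target statistics $\mu_t^{\ell}, \Sigma_t^{\ell}$ are estimated from incoming batches, optionally through an exponential moving average (EMA). PEA then applies whitening-coloring transform (WCT) alignment:

$$\hat h^{\ell} = \big(\Sigma_s^{\ell}\big)^{1/2} \big(\Sigma_t^{\ell}\big)^{-1/2} \big(h^{\ell} - \mu_t^{\ell}\big) + \mu_s^{\ell}.$$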

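Here $h^{\ell}$ denotes the current feature tensor at block $\ell$, $\hat h^{\ell}$ is the source-aligned feature, and the source statistics $(\mu_s^{\ell}, \Sigma_s^{\ell})$ define the alignment target. PEA doesn't replace the feature directly. It interpolates between the original and aligned representations:

$$\tilde h^{\ell} = (1 - \alpha_\ell)\, h^{\ell} + \alpha_\ell\, \hat h^{\ell}.$$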

The scalar $\alpha_\ell \in [0, 1]$ controls how strongly block $\ell$ is corrected. In the original PEA procedure, a first forward pass estimates these layer-wise weights. A second forward pass then applies WCT-based alignment and produces the prediction.
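The NumPy sketch below implements the WCT alignment and the interpolation for a single block. The helper `sym_power` computes matrix powers through an eigendecomposition and floors small eigenvalues at `eps`. That regularization, the function names, and the synthetic data are illustrative assumptions, not details of the PEA implementation. The synthetic target applies per-axis scaling, a rotation, and a mean shift to the source features, mirroring the deformations PEA is designed to undo.

```python
import numpy as np


def sym_power(mat, p, eps=1e-5):
    """Raise a symmetric PSD matrix to power p; eigenvalues are floored at eps for stability."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, eps, None)
    return (vecs * vals**p) @ vecs.T


def wct_align(h, mu_t, sigma_t, mu_s, sigma_s):
    """Whiten features h (N, C) with target statistics, then color them with source statistics."""
    w = sym_power(sigma_s, 0.5) @ sym_power(sigma_t, -0.5)
    return (h - mu_t) @ w.T + mu_s


def blend(h, h_aligned, alpha):
    """Interpolate between the original and aligned features; alpha in [0, 1]."""
    return (1.0 - alpha) * h + alpha * h_aligned


rng = np.random.default_rng(0)
src = rng.normal(size=(5000, 4))
mu_s, sigma_s = src.mean(axis=0), np.cov(src, rowvar=False)

# Synthetic target: per-axis scaling, a random rotation, then a mean shift.
rot, _ = np.linalg.qr(rng.normal(size=(4, 4)))
tgt = (src * np.array([2.0, 0.5, 1.0, 3.0])) @ rot.T + 1.5
mu_t, sigma_t = tgt.mean(axis=0), np.cov(tgt, rowvar=False)

aligned = wct_align(tgt, mu_t, sigma_t, mu_s, sigma_s)
fully_corrected = blend(tgt, aligned, alpha=1.0)
untouched = blend(tgt, aligned, alpha=0.0)
```

With `alpha=1.0`, the corrected batch reproduces the source mean and covariance. This holds only because the target statistics are estimated from the same batch. With a running EMA estimate, the match would be approximate.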

Single-forward-pass enhancement

Let's start by removing the preliminary unaligned pass. Instead of computing $\alpha_\ell$ from a separate batch traversal, we derive the alignment weights from the maintained target EMA statistics and the stored source statistics.

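For grouped covariance blocks, with the channels of block $\ell$ split into $G$ groups, write the source and target covariances as

$$\Sigma_s^{\ell} = \operatorname{blockdiag}\big(\Sigma_{s,1}^{\ell}, \dots, \Sigma_{s,G}^{\ell}\big), \qquad \Sigma_t^{\ell} = \operatorname{blockdiag}\big(\Sigma_{t,1}^{\ell}, \dots, \Sigma_{t,G}^{\ell}\big).$$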
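Here is a minimal sketch of the grouped statistics. It assumes the channels are split into equal-sized contiguous groups; that grouping scheme is our assumption. `block_diag` is a small hand-written helper that assembles the per-group covariances into a block-diagonal matrix, so each group is whitened and colored independently.

```python
import numpy as np


def grouped_covariances(features, num_groups):
    """Split C channels into equal contiguous groups and return one covariance per group."""
    n, c = features.shape
    if c % num_groups != 0:
        raise ValueError("channel count must be divisible by num_groups")
    size = c // num_groups
    centered = features - features.mean(axis=0)
    blocks = []
    for g in range(num_groups):
        x = centered[:, g * size:(g + 1) * size]
        blocks.append(x.T @ x / (n - 1))
    return blocks


def block_diag(blocks):
    """Place each group covariance on the diagonal; cross-group entries stay zero."""
    total = sum(b.shape[0] for b in blocks)
    out = np.zeros((total, total))
    start = 0
    for b in blocks:
        k = b.shape[0]
        out[start:start + k, start:start + k] = b
        start += k
    return out


rng = np.random.default_rng(1)
feats = rng.normal(size=(1000, 12))
blocks = grouped_covariances(feats, num_groups=3)
full = block_diag(blocks)
```

Grouping trades off expressiveness for cost. Each group needs only a small eigendecomposition, and cross-group correlations are ignored by design.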

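The proposed discrepancy for block $\ell$ is

$$D_\ell = \big\lVert \mu_t^{\ell} - \mu_s^{\ell} \big\rVert_2 + \sum_{g=1}^{G} \big\lVert \Sigma_{t,g}^{\ell} - \Sigma_{s,g}^{\ell} \big\rVert_F.$$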

This combines mean mismatch with grouped covariance mismatch over the same block structure used by the WCT operation. The alignment gate is bounded by an exponential contraction:

$$\alpha_\ell = 1 - \exp\!\left(-\frac{D_\ell}{\tau_\ell + \epsilon}\right).$$

The parameter $\tau_\ell$ normalizes the discrepancy scale across layers, and $\epsilon$ stabilizes the denominator. The exponential form gives a smooth gate in $[0, 1)$: near-zero discrepancy produces near-zero intervention, while larger mismatch increases the correction strength. The discrepancy $D_\ell$ should be read as a Frobenius-norm proxy for source-target geometric mismatch. It is transport-inspired only in the limited sense that it measures the displacement of first and second moments; it is not a claim of exact optimal transport. $D_\ell$ sets the strength of the intervention, while the correction direction comes from the source-anchored statistics.
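The discrepancy and gate take only a few lines of NumPy. This sketch assumes unsquared Euclidean and Frobenius norms summed over groups. The value `tau=1.0` and the toy statistics are placeholders, since the text does not specify how $\tau_\ell$ is calibrated.

```python
import numpy as np


def block_discrepancy(mu_t, mu_s, sigma_t_blocks, sigma_s_blocks):
    """Mean mismatch plus summed Frobenius covariance mismatch over the group blocks."""
    mean_term = np.linalg.norm(mu_t - mu_s)
    cov_term = sum(np.linalg.norm(st - ss, ord="fro")
                   for st, ss in zip(sigma_t_blocks, sigma_s_blocks))
    return mean_term + cov_term


def alignment_gate(discrepancy, tau, eps=1e-6):
    """Smooth gate in [0, 1): zero mismatch gives zero intervention."""
    return 1.0 - np.exp(-discrepancy / (tau + eps))


mu_s = np.zeros(4)
sig_s = [np.eye(2), np.eye(2)]
d_same = block_discrepancy(mu_s, mu_s, sig_s, sig_s)
d_small = block_discrepancy(mu_s + 0.1, mu_s, sig_s, sig_s)
d_large = block_discrepancy(mu_s + 1.0, mu_s, [2 * np.eye(2), 3 * np.eye(2)], sig_s)
gates = [alignment_gate(d, tau=1.0) for d in (d_same, d_small, d_large)]
```

Identical statistics give a gate of exactly zero. As the mismatch grows, the gate approaches one but never reaches it, so the original features always keep some weight.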

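The mean correction direction is

$$\Delta\mu^{\ell} = \mu_s^{\ell} - \mu_t^{\ell}.$$

The covariance correction direction is encoded by the whitening-coloring map, applied group by group:

$$W^{\ell} = \operatorname{blockdiag}\big(W_1^{\ell}, \dots, W_G^{\ell}\big), \qquad W_g^{\ell} = \big(\Sigma_{s,g}^{\ell}\big)^{1/2} \big(\Sigma_{t,g}^{\ell}\big)^{-1/2}.$$

Substituting both directions into the interpolation gives

$$\tilde h^{\ell} = h^{\ell} + \alpha_\ell \Big[\Delta\mu^{\ell} + \big(W^{\ell} - I\big)\big(h^{\ell} - \mu_t^{\ell}\big)\Big].$$

The update should therefore be read as applying a source-anchored affine map in the current layer's feature coordinates, with $\alpha_\ell$ setting the blend strength.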
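Because the blend is affine in $h^{\ell}$, it can be collapsed into one matrix $A = (1-\alpha_\ell) I + \alpha_\ell W^{\ell}$ and one offset $c = \alpha_\ell(\mu_s^{\ell} - W^{\ell}\mu_t^{\ell})$ per block, with $W^{\ell} = (\Sigma_s^{\ell})^{1/2}(\Sigma_t^{\ell})^{-1/2}$. The sketch below checks that identity numerically. Folding the correction this way is our illustration, not something described in [1]. The names and toy data are hypothetical.

```python
import numpy as np


def sym_power(mat, p, eps=1e-5):
    """Raise a symmetric PSD matrix to power p via eigendecomposition (eigenvalues floored at eps)."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, eps, None)
    return (vecs * vals**p) @ vecs.T


def gated_affine_map(mu_t, sigma_t, mu_s, sigma_s, alpha):
    """Return (A, c, W) so that the gated WCT update equals h -> A h + c for each feature row."""
    w = sym_power(sigma_s, 0.5) @ sym_power(sigma_t, -0.5)
    a = (1.0 - alpha) * np.eye(w.shape[0]) + alpha * w
    c = alpha * (mu_s - w @ mu_t)
    return a, c, w


rng = np.random.default_rng(2)
h = rng.normal(size=(256, 3)) * 2.0 + 0.5
mu_t, sigma_t = h.mean(axis=0), np.cov(h, rowvar=False)
mu_s, sigma_s = np.zeros(3), np.eye(3)
alpha = 0.4

a, c, w = gated_affine_map(mu_t, sigma_t, mu_s, sigma_s, alpha)
affine_out = h @ a.T + c                                          # folded form
blend_out = (1 - alpha) * h + alpha * ((h - mu_t) @ w.T + mu_s)   # interpolation form
```

Precomputing $A$ and $c$ once per batch reduces the per-block correction to one matrix multiply and one addition. We have not measured whether this changes latency in practice.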

Adaptive target statistics

There is a catch, however: test-time batches can be small, noisy, or temporally correlated. Replacing feature statistics abruptly can overreact to a single batch, while a fixed, slow EMA lags behind abrupt drift. A continuous, drift-dependent momentum addresses both failure modes.

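For the batch covariance $\Sigma_b^{\ell}$, define the relative drift with respect to the current target estimate:

$$\delta_\ell = \frac{\big\lVert \Sigma_b^{\ell} - \Sigma_t^{\ell} \big\rVert_F}{\big\lVert \Sigma_t^{\ell} \big\rVert_F + \epsilon}.$$

The adaptive momentum $m_\ell$ is a bounded, increasing function of the drift. One smooth form mirrors the exponential contraction of the alignment gate. It rises from a base value $m_0$ toward a ceiling $m_{\max}$ with scale $\kappa$:

$$m_\ell = m_0 + \big(m_{\max} - m_0\big)\left(1 - \exp\!\left(-\frac{\delta_\ell}{\kappa}\right)\right).$$

With $\mu_b^{\ell}$ the batch mean, the target statistics are then updated by EMA:

$$\mu_t^{\ell} \leftarrow (1 - m_\ell)\,\mu_t^{\ell} + m_\ell\,\mu_b^{\ell}, \qquad \Sigma_t^{\ell} \leftarrow (1 - m_\ell)\,\Sigma_t^{\ell} + m_\ell\,\Sigma_b^{\ell}.$$

Stable drift keeps $m_\ell$ near $m_0$, giving slow temporal smoothing. Abrupt drift increases $\delta_\ell$ and therefore $m_\ell$. This allows faster adaptation through the same EMA equation rather than a separate hard-reset branch.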
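The adaptive statistics can be sketched as a small stateful class. The hyperparameters `m0`, `m_max`, and `kappa`, and the exponential form of the momentum, are illustrative assumptions. Any bounded, increasing map from drift to momentum fits the behavior described above.

```python
import numpy as np


class AdaptiveTargetStats:
    """Per-block target mean/covariance tracked by a drift-dependent EMA.

    m0, m_max and kappa are hypothetical hyperparameters chosen for illustration.
    """

    def __init__(self, mu, sigma, m0=0.05, m_max=0.9, kappa=1.0, eps=1e-6):
        self.mu = mu.copy()
        self.sigma = sigma.copy()
        self.m0, self.m_max, self.kappa, self.eps = m0, m_max, kappa, eps

    def relative_drift(self, sigma_b):
        """Frobenius distance of the batch covariance from the running estimate, relative to its size."""
        num = np.linalg.norm(sigma_b - self.sigma, ord="fro")
        return num / (np.linalg.norm(self.sigma, ord="fro") + self.eps)

    def momentum(self, drift):
        """Rises smoothly from m0 (stable stream) toward m_max (abrupt drift)."""
        return self.m0 + (self.m_max - self.m0) * (1.0 - np.exp(-drift / self.kappa))

    def update(self, batch):
        """Fold a (N, C) batch into the running statistics and return the momentum used."""
        mu_b = batch.mean(axis=0)
        sigma_b = np.cov(batch, rowvar=False)
        m = self.momentum(self.relative_drift(sigma_b))
        self.mu = (1.0 - m) * self.mu + m * mu_b
        self.sigma = (1.0 - m) * self.sigma + m * sigma_b
        return m


rng = np.random.default_rng(3)
stats = AdaptiveTargetStats(np.zeros(4), np.eye(4))
m_stable = stats.update(rng.normal(size=(512, 4)))        # in-distribution batch
m_abrupt = stats.update(rng.normal(size=(512, 4)) * 3.0)  # sudden variance jump
```

In this example, an in-distribution batch yields a momentum close to `m0`. A batch with three times the standard deviation pushes the momentum close to `m_max`, so the estimate catches up within a batch or two without a reset branch.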

Layer-wise corrections and input-space drift

The correction is affine only within a specific intermediate representation. Because it is inserted between nonlinear network blocks, the end-to-end transformation induced in input space isn’t equal to a single global linear transform. The method doesn’t require raw input drift to be globally uniform. It assumes that, after projection through the trained model, a substantial component of the deployment shift appears as low-order mismatch in intermediate representations.
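A toy check makes this concrete. Each hypothetical correction below is affine on its own, so it satisfies the midpoint identity $f\big(\tfrac{x+y}{2}\big) = \tfrac{f(x)+f(y)}{2}$. Inserting a ReLU between two such corrections breaks the identity, so the end-to-end map is not a single affine transform. The matrices and offsets are arbitrary illustrative values.

```python
import numpy as np


def affine(a, c):
    """Return the affine map x -> a x + c applied row-wise."""
    return lambda x: x @ a.T + c


def relu(z):
    return np.maximum(z, 0.0)


# Hypothetical per-block corrections: each one is affine in its own feature space.
corr1 = affine(np.diag([1.5, 0.5]), np.array([0.2, -1.0]))
corr2 = affine(np.array([[0.0, 1.0], [1.0, 0.0]]), np.array([0.5, 0.5]))


def end_to_end(x):
    """Block-1 correction, then a nonlinearity, then the block-2 correction."""
    return corr2(relu(corr1(x)))


def midpoint_gap(f, x, y):
    """Zero for any affine f; a nonzero value shows f is not globally affine."""
    return np.abs(f((x + y) / 2) - (f(x) + f(y)) / 2).max()


x = np.array([[1.0, 1.0]])
y = np.array([[-1.0, 3.0]])
gap_block = midpoint_gap(corr1, x, y)       # ~0: a single correction is affine
gap_total = midpoint_gap(end_to_end, x, y)  # about 0.65: the composition is not
```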

This operating assumption is limited. The method isn’t guaranteed to correct arbitrary class-conditional shifts, label-destroying shifts, class inversions, or cases where semantic clusters collapse or reorder irreversibly. In those regimes, unlabeled global alignment alone may be insufficient and may require covariates, class-conditional statistics, pseudo-labels, or domain-specific supervision.

Experimental setup and results

We evaluate on CIFAR10-C using a ViT-S backbone at corruption severity 5, with batch size 512, on a Tesla V100 GPU. Clean accuracy is 0.8040. Accuracy is higher-is-better. mCE (mean corruption error) is the average corruption error normalized by the no-adaptation baseline, so no adaptation scores 1.000 by construction. ECE (expected calibration error) is the average gap between predicted confidence and empirical accuracy across confidence bins. mCE, ECE, latency, FLOPs, and memory are all lower-is-better.
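The two derived metrics can be sketched as follows. We assume mCE averages per-corruption error ratios (method error divided by no-adaptation error). We also assume ECE uses equal-width confidence bins weighted by bin population. The bin counts and the toy numbers in the example are placeholders, not values from this experiment.

```python
import numpy as np


def mean_corruption_error(method_errors, baseline_errors):
    """Average over corruptions of the method's error divided by the no-adaptation error."""
    ratios = [method_errors[k] / baseline_errors[k] for k in baseline_errors]
    return float(np.mean(ratios))


def expected_calibration_error(confidences, correct, num_bins=15):
    """Population-weighted mean |accuracy - confidence| over equal-width confidence bins."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece


# Toy, made-up error rates per corruption type.
baseline = {"gaussian_noise": 0.60, "fog": 0.30}
adapted = {"gaussian_noise": 0.45, "fog": 0.27}
mce_baseline = mean_corruption_error(baseline, baseline)  # 1.0 by construction
mce_adapted = mean_corruption_error(adapted, baseline)
ece = expected_calibration_error([0.85, 0.85, 0.65, 0.65], [1, 0, 1, 1], num_bins=10)
```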

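| Method | Degraded Acc. (↑) | mCE (↓) | Shift ECE (↓) | Latency, ms/img (↓) | FLOPs, G (↓) | Memory, MB (↓) |
|---|---|---|---|---|---|---|
| No adaptation | 0.580 | 1.000 | 0.1257 | 0.2375 | 9.465 | 735 |
| PEA | 0.626 | 0.9156 | 0.0887 | 0.564 | 19.74 | 1,681 |
| Enhanced PEA, single pass | 0.642 | 0.8933 | 0.0923 | 0.349 | 10.28 | 982 |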

Under this setting and relative to PEA, the enhanced variant improves mean corrupted accuracy from 0.6257 to 0.6419 and mCE from 0.9156 to 0.8933. It also reduces system cost: latency drops from 0.564 to 0.349 ms/image (about 38%), FLOPs from 19.74 G to 10.28 G (about 48%), and memory from 1,681 MB to 982 MB (about 42%).
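The quoted reductions can be recomputed directly from the table values:

```python
# Values copied from the results table (PEA vs. enhanced single-pass PEA).
pea = {"latency_ms": 0.564, "flops_g": 19.74, "memory_mb": 1681}
single = {"latency_ms": 0.349, "flops_g": 10.28, "memory_mb": 982}

reductions = {k: 100 * (pea[k] - single[k]) / pea[k] for k in pea}
for name, pct in reductions.items():
    print(f"{name}: {pct:.1f}% lower")
# latency_ms: 38.1% lower
# flops_g: 47.9% lower
# memory_mb: 41.6% lower
```

The FLOPs gap, about 9.46 G, is close to the cost of the no-adaptation forward pass in the table. That is what removing one of PEA's two passes would predict.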

Calibration remains improved relative to no adaptation: shifted ECE is 0.0923 for the single-forward-pass variant versus 0.1257 without adaptation. PEA has slightly lower ECE in this result, at 0.0887. The result supports the single-pass formulation as a practical efficiency improvement in this experiment, rather than a universal solution to all distribution shifts.

Conclusion

PEA's source-anchored embedding alignment can be made more suitable for real-time and edge deployment by replacing the first forward pass with statistics-driven alignment gates. The enhanced PEA variant remains backpropagation-free, doesn't update model parameters, and operates only on intermediate activations. Its core assumption holds best when deployment shifts remain expressible as low-order statistical mismatch in the learned representation space.

References

[1] Xiao Ma, Young D. Kwon, Pan Zhou, and Dong Ma. “Architecture-Agnostic Test-Time Adaptation via Backprop-Free Embedding Alignment.” ICLR, 2026.

Cite this post

@misc{poiret2026enhancing_pea,
  title = {Enhancing Progressive Embedding Alignment for Test-Time Adaptation},
  author = {Poiret, Clement},
  year = {2026},
  month = jun,
  howpublished = {Technical blog post},
  url = {https://rhizome-labs.com/blog/enhancing-progressive-embedding-alignment},
  urldate = {2026-06-29}
}