I am going to dive deep into a publication by Lin et al. (here), as I believe it is a state-of-the-art contribution.
The core problem this paper points out is that predicting and understanding the performance of machine learning (ML) training on multi-GPU platforms is extremely difficult due to several interacting factors. These include the complexity of synchronization and load balancing between heterogeneous compute resources (CPUs and GPUs), the variability in input data distribution, and the heterogeneity of communication infrastructure, such as NVLink, PCIe, and network interconnects. Most existing performance models are either platform-specific, limited to single-GPU settings, or fail to generalize across different workloads and hardware configurations. As a result, it is very challenging to accurately estimate per-iteration training time, optimize training configurations (such as how to shard embedding tables), or make informed deployment decisions without extensive trial-and-error runs on actual hardware. In short, this paper proposes a statistical and simulation-based model that accurately predicts the iteration time (i.e., training step time) of DLRM (Deep Learning Recommendation Model) and NLP workloads running on multi-GPU systems. I stopped here and asked myself: "So what? Why is predicting iteration time important? Who cares how long an iteration takes?" To answer these questions, I had to do extensive research into the intrinsic and inherent behavior of ML workloads. Long story short, here are the reasons:
Predicting iteration time in multi-GPU training is crucial not for its own sake, but because it directly informs key decisions in large-scale ML systems where time, cost, and efficiency are important. Each training run consists of millions of iterations, and even small increases in per-iteration latency can translate to hours or days of extra compute, costing thousands of dollars on large clusters. Accurate iteration time prediction enables engineers to evaluate different sharding strategies, parallelization schemes, and hardware configurations without running expensive experiments, thereby accelerating system design and tuning. It also supports intelligent resource scheduling, allowing platforms to allocate GPUs more effectively and meet performance guarantees. In essence, iteration time is the atomic unit of training performance, and being able to predict it with high accuracy enables better decisions across the entire ML training stack, from cost estimation and throughput optimization to infrastructure scaling and workload scheduling.
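To make the cost argument concrete, here is a quick back-of-envelope calculation in Python. All the numbers (iteration count, per-iteration times, GPU count, hourly price) are illustrative assumptions, not figures from the paper:

```python
# Back-of-envelope: how a small per-iteration slowdown compounds over a full run.
# All numbers below are illustrative assumptions, not values from the paper.

iterations = 2_000_000          # total training iterations for a large model
base_iter_time_s = 0.150        # assumed baseline per-iteration time (seconds)
slow_iter_time_s = 0.165        # same workload with a 10% per-iteration slowdown
num_gpus = 64                   # assumed cluster size
price_per_gpu_hour = 2.50       # assumed cloud price (USD per GPU-hour)

def training_cost(iter_time_s: float) -> tuple[float, float]:
    """Return (wall-clock hours, total USD cost) for the whole run."""
    hours = iterations * iter_time_s / 3600
    cost = hours * num_gpus * price_per_gpu_hour
    return hours, cost

base_hours, base_cost = training_cost(base_iter_time_s)
slow_hours, slow_cost = training_cost(slow_iter_time_s)

print(f"baseline: {base_hours:.1f} h, ${base_cost:,.0f}")
print(f"slowed  : {slow_hours:.1f} h, ${slow_cost:,.0f}")
print(f"extra   : {slow_hours - base_hours:.1f} h, ${slow_cost - base_cost:,.0f}")
```

Even a 10% per-iteration slowdown adds several hours of wall-clock time and a four-figure cost on this hypothetical cluster, which is exactly why predicting iteration time before running the job matters.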
Before digging deeper, I find it helpful to explain a few terms:
1- Input data distribution
The way data is divided and assigned to different GPUs (or nodes) can significantly affect training performance. If some GPUs receive more data than others (say, due to non-uniform sampling, shuffling, or padding), they take longer to finish processing their mini-batches. This causes load imbalance: the faster GPUs must wait for the slowest one at synchronization points (like all-reduce or barrier ops), wasting compute cycles. Especially in recommendation systems or embedding-heavy models, some GPUs may handle more frequent or "hotter" features (e.g., popular user IDs or categories), resulting in heavier memory access or communication load. This again disrupts parallelism efficiency. When large datasets are partitioned across GPUs (such as sharded embedding tables), how you divide the dataset can impact throughput. An imbalanced partition leads to inconsistent training times per step, communication overhead, and memory pressure.

This reminds me of the CTA allocation policy. CTA (Cooperative Thread Array) allocation policy and input data distribution variability are related, but they tackle different layers of the system, and it is important to distinguish them. Both affect how work is assigned across GPUs (or SMs within a GPU), and both can lead to load imbalance, but they operate at different abstraction levels. Input data distribution refers to how the training dataset itself is divided across multiple GPUs. For example, in data-parallel training, if one GPU gets more samples, or more complex samples, than another, it takes longer to compute. So this is about the data, not the execution model. CTA allocation policy is about how work is distributed across streaming multiprocessors (SMs) inside a GPU, or across GPUs in some multi-GPU scheduling schemes. When a kernel is launched (e.g., matrix multiplication, embedding lookup), the scheduler decides which GPU or which SM gets how many CTAs. So CTA policy is about how computation kernels are scheduled, not how the input data is divided. If the input data distribution is unbalanced, then even the best CTA scheduling won't fix the global performance imbalance, because one GPU might still have far more to do. On the flip side, even if data is perfectly balanced, poor CTA allocation could cause underutilization. In this paper's context, variability in input data distribution refers to data-level imbalances across GPUs, which then propagate into workload execution, synchronization overhead, and communication cost. The small sketch below shows how skewed feature popularity alone can create this kind of per-GPU imbalance.
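To see how skewed ("hot") features create data-level imbalance, here is a tiny, hypothetical sketch: a single embedding table is range-sharded across four GPUs, and a handful of hot rows happen to live on GPU 0's shard. All the numbers and the sharding scheme are my own assumptions, not the paper's:

```python
import random
from collections import Counter

# Toy illustration (hypothetical numbers, not from the paper): when a few
# feature IDs are "hot", a range-sharded embedding table serves very uneven
# lookup traffic per GPU.

random.seed(0)
num_rows, num_gpus, num_lookups = 1_000_000, 4, 100_000
rows_per_gpu = num_rows // num_gpus

# Assume 100 popular rows (popular users/items) all fall in GPU 0's shard.
hot_ids = random.sample(range(0, rows_per_gpu), k=100)

def sample_id() -> int:
    # 70% of lookups hit a hot row, 30% are spread uniformly over the table.
    if random.random() < 0.7:
        return random.choice(hot_ids)
    return random.randrange(num_rows)

# Range-based sharding: GPU g owns rows [g * rows_per_gpu, (g + 1) * rows_per_gpu).
lookups_per_gpu = Counter(sample_id() // rows_per_gpu for _ in range(num_lookups))

for gpu in range(num_gpus):
    share = 100 * lookups_per_gpu[gpu] / num_lookups
    print(f"GPU {gpu}: {lookups_per_gpu[gpu]:6d} lookups ({share:4.1f}% of traffic)")
```

GPU 0 ends up serving the large majority of lookup traffic even though every GPU owns the same number of rows, which is exactly the kind of data-level skew the paper's model has to be aware of.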
2- Per-iteration training time
Machine learning training is iterative, typically involving many thousands of iterations, where each iteration processes a mini-batch of data, performs forward and backward passes, and updates the model. The per-iteration training time refers to how long a single one of these iterations takes to complete, across all GPUs and all necessary compute and communication steps. If you can estimate how long each iteration will take on a given configuration (e.g., number of GPUs, batch size, sharding method, communication backend), you can select the best setup without running trial-and-error experiments. In production systems or cloud environments (e.g., AWS, Azure), knowing per-iteration time lets you predict total training cost and time-to-train (TTT), and decide how much hardware you need; this is especially useful for large models like GPT, BERT, or DLRMs that train over weeks. If your per-iteration time suddenly increases or doesn't scale with added GPUs, it can point to a bottleneck in compute, memory, or communication, and a good model helps you pinpoint the culprit: is it GPU saturation? Poor PCIe bandwidth? Imbalanced data? Frameworks like PyTorch, TensorFlow, or Ray use automated schedulers to allocate resources, and accurate per-iteration estimation helps these systems make smarter allocation decisions in real time. Want to know if switching from PCIe to NVLink will help, or how training will scale from 4 to 16 GPUs? Accurate per-iteration estimation gives you that insight without rerunning the whole job. The measurement sketch below shows what a single "iteration" actually covers in practice.
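As a point of reference, here is a minimal sketch of how per-iteration time is typically measured in PyTorch. The model, data, and iteration counts are placeholders I made up; the point is simply that one iteration spans the forward pass, backward pass, and optimizer step:

```python
import time
import torch

# Minimal per-iteration timing loop (hypothetical model/data, single GPU).
# torch.cuda.synchronize() matters: GPU kernels run asynchronously, so without
# it the host-side timer would stop before the GPU work actually finishes.

device = "cuda"
model = torch.nn.Linear(1024, 1024).to(device)           # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

iter_times = []
for _ in range(20):
    x = torch.randn(512, 1024, device=device)            # placeholder mini-batch
    y = torch.randn(512, 1024, device=device)

    torch.cuda.synchronize()
    start = time.perf_counter()

    optimizer.zero_grad()
    loss = loss_fn(model(x), y)    # forward
    loss.backward()                # backward
    optimizer.step()               # parameter update

    torch.cuda.synchronize()
    iter_times.append(time.perf_counter() - start)

# Skip the first few iterations (warm-up, allocator and autotuning noise).
steady = iter_times[5:]
print(f"mean per-iteration time: {1000 * sum(steady) / len(steady):.2f} ms")
```

In a multi-GPU run the same window would also include the collective communication and synchronization costs, which is precisely what the paper's model tries to predict without running the job.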
The authors propose a universal, modular performance modeling framework designed to accurately estimate per-iteration training time without requiring full execution. At its core, the solution integrates several key components that, together, simulate the actual cost of one training iteration with high accuracy. These components are data-distribution-aware, meaning they take into account how the data is split and used across GPUs.
I am particularly intrigued by the figure below, which illustrates the per-GPU-stream training execution time breakdown of selected ML workloads on a 4-GPU platform. It shows that:
DLRM behavior: These workloads exhibit a long, communication-dominated critical path, especially due to all-to-all ops. They are bottlenecked by memory access patterns, communication, and synchronization rather than raw compute.
NLP behavior: NLP workloads are dominated by compute and have simpler communication patterns (all-reduce), mostly at known points.
This exploration is important because the modeling challenges are different for each workload type. It shows that a uniform modeling strategy won't work; we need workload-specific sensitivity to accurately simulate iteration time.
Here are the main components of the modeling framework:
1- Communication Collectives Model: This model predicts how long it takes to perform all the required data exchanges between GPUs during one iteration, based on the size of the messages, the type of collective op, and the topology and bandwidth of the hardware, and it feeds this estimate into the overall per-iteration performance prediction. For example, within one training iteration of one of their models, they identify three distinct types of communication collectives.
Each of these is independently modeled using its own fitted sigmoid function; the sketch below illustrates that curve-fitting idea.
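To make the sigmoid idea concrete, here is a small curve-fitting sketch over synthetic measurements. The functional form (a sigmoid in the log of the message size) and every constant in it are my own assumptions for illustration, not the paper's exact formulation:

```python
import numpy as np
from scipy.optimize import curve_fit

# Sketch: fit a sigmoid-shaped curve to (message size -> collective latency)
# measurements, in the spirit of the paper's per-collective sigmoid models.
# The "measurements" below are synthetic placeholders, not data from the paper.

def sigmoid_latency(log_size, a, b, c, d):
    # Latency modeled as a sigmoid in log2(message size): a flat launch-bound
    # region for small messages, then a bandwidth-bound ramp for large ones.
    return a + b / (1.0 + np.exp(-c * (log_size - d)))

# Synthetic all-to-all timings: message sizes in bytes, latencies in milliseconds.
sizes = np.array([2.0 ** k for k in range(10, 28)])
latency_ms = 0.05 + 0.9 / (1.0 + np.exp(-1.1 * (np.log2(sizes) - 22)))
latency_ms = latency_ms + np.random.default_rng(0).normal(0.0, 0.01, sizes.size)

params, _ = curve_fit(sigmoid_latency, np.log2(sizes), latency_ms,
                      p0=[0.05, 1.0, 1.0, 20.0])

# Use the fitted curve to predict an unseen message size (e.g., 48 MiB).
pred_ms = sigmoid_latency(np.log2(48 * 2 ** 20), *params)
print("fitted parameters (a, b, c, d):", np.round(params, 3))
print(f"predicted all-to-all latency for 48 MiB: {pred_ms:.3f} ms")
```

Once such a curve is fitted per collective type and per interconnect, predicting the communication cost of an iteration reduces to evaluating the curves at the message sizes the workload would produce.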
2- Inter-Rank and Intra-Rank Synchronization Model: In this context, a rank means a GPU. This model captures the time the entire system must wait for all ranks to reach a synchronization point (e.g., after compute or communication) before proceeding to the next step. In practice, not all GPUs or nodes finish at the same time: some might be doing slightly more computation, have slower communication paths (e.g., PCIe vs. NVLink), experience contention, or simply be assigned harder data. As a result, faster ranks must wait for the slowest rank, which wastes time and adds non-trivial overhead to each iteration. This is called the straggler effect or tail latency. The same is true within a rank, among SMs. The sketch below shows the basic inter-rank version of this accounting.
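Here is a minimal sketch of the inter-rank part of that accounting, with made-up per-rank times (the intra-rank, per-SM version follows the same max-based logic):

```python
# Toy inter-rank synchronization model with hypothetical per-rank timings (ms).
# Each rank's "busy" time is its compute plus its exposed communication time;
# at the barrier, every rank waits for the slowest one.

busy_ms = {  # rank -> compute + exposed communication time for one iteration
    0: 41.0,
    1: 43.5,
    2: 40.2,
    3: 47.8,   # straggler: harder data and/or a slower interconnect path
}

iteration_ms = max(busy_ms.values())   # barrier releases when the last rank arrives
wait_ms = {rank: iteration_ms - t for rank, t in busy_ms.items()}

for rank, t in busy_ms.items():
    print(f"rank {rank}: busy {t:5.1f} ms, waits {wait_ms[rank]:4.1f} ms at the sync point")
print(f"predicted iteration time: {iteration_ms:.1f} ms "
      f"(vs. {sum(busy_ms.values()) / len(busy_ms):.1f} ms if perfectly balanced)")
```

The wasted wait time is pure overhead, which is why the paper treats synchronization as a first-class component of the iteration time rather than folding it into compute or communication.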
3- Embedding Lookup Time Model: In machine learning, an embedding is just a vector representation of something that is not originally a number, like a word, a user ID, a movie title, or a product. Imagine you have the word "apple". You can't feed this into a neural network directly, so instead you look it up in a table that gives you a 128-dimensional or 512-dimensional vector representing "apple" numerically. This table is called an embedding table, and it can be very big, sometimes millions of rows for things like users or products. Simply put, the amount of time it takes for the system (especially the GPU) to fetch the right vectors from the embedding table is called the embedding lookup time. In some ML models, like recommendation systems, embedding lookups are the biggest chunk of work; they dominate both compute time and memory usage. The model uses this kind of information to estimate how long embedding lookups will take on each GPU, and then feeds that into the overall training time prediction. Instead of hand-crafting an analytical latency model for embedding lookups, the authors train a supervised regression model that predicts lookup latency based on observed features. This is the most statistical/learned part of their performance modeling pipeline; a small regression sketch follows below.
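Here is a minimal sketch of that kind of learned latency model using scikit-learn. The feature set (batch size, number of tables, rows per table, embedding dimension, pooling factor) and the synthetic latency function are my own guesses at plausible inputs, not the paper's exact features or data:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split

# Sketch of a learned embedding-lookup latency model on synthetic data.
# Features (assumed, not the paper's exact list): batch size, number of tables,
# rows per table, embedding dimension, average pooling factor.

rng = np.random.default_rng(0)
n = 2000
X = np.column_stack([
    rng.integers(256, 8192, n),            # batch size
    rng.integers(1, 64, n),                # embedding tables on this GPU
    rng.integers(10_000, 10_000_000, n),   # rows per table
    rng.choice([64, 128, 256], n),         # embedding dimension
    rng.integers(1, 100, n),               # average pooling factor (IDs per sample)
])

# Synthetic "measured" latency (ms): roughly proportional to bytes moved, plus noise.
latency = 0.2 + 1e-7 * X[:, 0] * X[:, 1] * X[:, 3] * X[:, 4] / 64
latency = latency + rng.normal(0.0, 0.05, n)

X_train, X_test, y_train, y_test = train_test_split(X, latency, random_state=0)
model = GradientBoostingRegressor().fit(X_train, y_train)

pred = model.predict(X_test)
mape = np.mean(np.abs(pred - y_test) / np.abs(y_test))
print(f"mean absolute percentage error on held-out configs: {100 * mape:.1f}%")
```

The appeal of this approach is that the regression model absorbs hardware- and kernel-level effects that would be tedious to capture analytically, as long as it is trained on measurements from the target platform.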
In the evaluation phase, the authors tested their performance modeling framework on two classes of machine learning models: embedding-heavy recommendation systems (DLRMs) and compute-intensive Transformer-based NLP models, across two different multi-GPU platforms with varied interconnects, including NVLink and PCIe. The results demonstrated that the framework was highly effective. For DLRM models, which involve significant embedding lookup and communication costs, the model achieved a geometric mean prediction error of 5.21%. For Transformer models, where compute dominates and communication is more structured, the error was even lower, at 3.00%. These error margins are considered impressively low, especially given the diversity of workloads and hardware. Beyond simple timing predictions, the model was tested for its ability to recommend optimal configurations without needing to run the full workload. In particular, the authors evaluated whether the framework could predict the best way to shard large embedding tables. The framework succeeded in identifying the fastest sharding configuration 85% of the time, demonstrating that it could guide critical system-level decisions without requiring empirical trial and error.
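As a side note on the metric, the geometric mean of per-workload prediction errors is computed as shown below; the individual error values here are made up for illustration, not the paper's per-workload numbers:

```python
import math

# Geometric mean of per-workload prediction errors (values are illustrative,
# not the paper's actual per-workload results).
errors_pct = [3.8, 6.5, 4.2, 7.9, 5.0]   # |predicted - measured| / measured, in %

geomean_pct = math.exp(sum(math.log(e) for e in errors_pct) / len(errors_pct))
print(f"geometric mean prediction error: {geomean_pct:.2f}%")
```

The geometric mean is a common choice here because it is less sensitive to a single outlier workload than the arithmetic mean.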
Open question for me:
1- Why does the paper only consider DLRM and NLP workloads? There are other ML/DL workload types as well.
Thoughts and opinions:
This paper stands from the point of view of the application rather than the hardware. It is logically convincing to predict the latency of a single iteration of an ML/DL workload. However, I am personally more interested in predicting the performance (IPC or throughput) of the system. That's why I came up with a few ideas that I believe have the potential to deliver appealing results by leveraging the core idea and insights from this paper.