In-Memory/Storage Compute to Maximize Performance, Security, and Reliability for Distributed AI Workloads

This post characterizes the importance of in-memory compute in today's infrastructure, mainly for large, distributed AI training and inference workloads, at a time when Moore's law is nearing the end of its life. Shrinking technology past 2nm is getting costlier and harder because of the sophistication of CMOS lithography. Current-generation AI workloads require very dense, distributed compute pipelines for tensor parallelism, expert parallelism (mixture of experts), and pipeline parallelism, in addition to data parallelism, in order to balance compute against communication and against the large volume of memory accesses these processing elements generate. This is driven mainly by throughput requirements (tokens/hour) and by the latency sensitivity of inference workloads, which use “custom data structures” like the KV cache, laid out across distributed instances in near/far memories and storage, to accomplish the various optimizations an LLM architecture needs. Quantum computing is still a long way off. Meanwhile, the memory hierarchy is complex and getting costlier.

Compute density must keep increasing, either through heavily multithreaded hardware (e.g., the SIMD multiprocessors in GPUs, called Streaming Multiprocessors or Compute Units in commercial offerings) or through thousands of fixed-pipeline tensor processing accelerators (e.g., systolic arrays). AI workloads need this compute to improve throughput (e.g., the performance of GEMM, convolution, and other math operators in MHA- or MoE-based transformer models) and to perform data precision transformations that maximize utilization of the available capacity in the fast/slow memory and storage hierarchies. Various optimizations are being applied in a highly distributed architecture at both the silicon and the platform level: mixed-precision data types (FP4/FP8/FP16, BF16, and similar integer formats), sparse formats like CSR (compressed sparse row) to speed up GEMM computations, and provisioning and accessing a larger KV cache for the inference prefill and decode stages over prompt tokens, in order to optimize TTFT (time to first token) and inter-token latency. Commercial silicon offerings are trying several approaches: 3D TSV (through-silicon via) structures to scale on-chip memory capacity for caches and constant and shared memories; multiple chiplets connected to a main base die through proprietary or UCIe-type high-speed interconnects using PAM4-type signaling; foundry-specific, non-scalable bonding techniques that connect multiple dies for faster time to market; and finally, when all else fails, small-form-factor PCB modules that house processing elements, memory modules, and other ancillary components as piggyback boards on the motherboard. These are silicon-level optimizations; another significant effort is underway to improve utilization through the compute and communication libraries used in such systems.
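
To make the CSR idea mentioned above concrete, here is a minimal sketch in plain Python. The function names `to_csr` and `csr_matvec` are illustrative and not taken from any library; production kernels would run this on the accelerator, but the storage layout is the same: only nonzero values, their column indices, and one offset per row are kept.

```python
def to_csr(dense):
    """Convert a dense row-major matrix (list of lists) into CSR arrays."""
    values, col_idx, row_ptr = [], [], [0]
    for row in dense:
        for j, v in enumerate(row):
            if v != 0:
                values.append(v)    # nonzero value
                col_idx.append(j)   # its column index
        row_ptr.append(len(values))  # offset where the next row starts
    return values, col_idx, row_ptr


def csr_matvec(values, col_idx, row_ptr, x):
    """Multiply a CSR matrix by a dense vector, touching only nonzeros."""
    y = []
    for r in range(len(row_ptr) - 1):
        acc = 0
        for k in range(row_ptr[r], row_ptr[r + 1]):
            acc += values[k] * x[col_idx[k]]
        y.append(acc)
    return y


dense = [
    [0, 0, 3, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 2, 0, 4],
]
values, col_idx, row_ptr = to_csr(dense)
y = csr_matvec(values, col_idx, row_ptr, [1, 2, 3, 4])
print(values, col_idx, row_ptr)  # 4 stored values instead of 16
print(y)
```

The matrix-vector product reads 4 values rather than 16, which is the kind of saving in memory traffic that makes sparse formats attractive when bandwidth is the bottleneck.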

One fact remains constant across all of the above optimizations: maintaining concurrency and data availability for such a large number of compute pipelines requires heavy memory read and write traffic. From a performance angle, all architectures try to maximize the overlap between the compute and communication primitives required for the various kinds of parallelism, using their low-level libraries. Thus, high memory bandwidth is paramount, in addition to low latency, for high compute throughput. Because AI workloads are distributed, they incur considerable overhead from communication collectives (e.g., collective library API calls) that require “all-to-all”, “all-reduce”, “all-gather”, and “reduce-scatter” network communication patterns, which consume the available memory capacity/bandwidth and network speeds/feeds. Most of these collectives are implemented with RDMA-style accesses over an underlying “Ethernet or IB fabric” that connects multiple processing elements, terminates in endpoint memories, and requires processing the associated communication verbs. In-memory compute pipelines can alleviate some of these overheads. Beyond collectives, scaling the KV cache requires IO access to fast NVMe-style storage to bring data into a faster memory tier (e.g., DDRx or HBMx) for the prefill and decode stages, in order to meet TTFT and TPOT (time per output token) targets for a given input prompt and its context, since the required KV cache capacity grows rapidly with the sequence length of the input prompt.
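
The pressure the KV cache puts on memory capacity can be estimated with a short sketch. It assumes the common layout in which every token stores one key vector and one value vector per layer and per KV head; the function name `kv_cache_bytes` and the model dimensions below (32 layers, 8 KV heads, head dimension 128, 2-byte elements) are hypothetical and chosen only for illustration.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim,
                   seq_len, batch_size, bytes_per_elem):
    """Bytes needed to hold keys and values for every cached token."""
    # Factor 2 accounts for storing both K and V.
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)


# Hypothetical model configuration (an assumption, not from the post).
LAYERS, KV_HEADS, HEAD_DIM, FP16_BYTES = 32, 8, 128, 2

for seq_len in (4096, 32768, 131072):
    size = kv_cache_bytes(LAYERS, KV_HEADS, HEAD_DIM, seq_len, 1, FP16_BYTES)
    print(f"seq_len={seq_len:>6}: {size / 2**30:.2f} GiB per sequence")
```

Under these assumptions the cache costs 128 KiB per token, so a single long-context request can take several gigabytes, and the total scales further with the number of concurrent requests. That is what pushes the cache out of HBM and into DDR or NVMe tiers.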

In-memory/storage compute can speed up training and inference in many ways by offloading some of the compute close to memory or storage. The following list is not exhaustive, but it is illustrative:

  • Speed up KV cache accesses (e.g., offload index calculation and hashing, and predicate computations on accessed data).
  • Inline data format transformations (e.g., fp16->fp32 or vice versa) while retrieving or storing.
  • Inline compression/decompression of accessed data to save capacity; this goes beyond format conversions like fpX to fpY. It can create a cold/warm/hot hierarchy of data within the memory address space based on compression, and it can substantially reduce the memory capacity required.
  • More intrusive, high-value operations such as MatMul, transposes of vectors and matrices, and convolution of filters over input matrices are considered very worthwhile and are being integrated by many Silicon Valley companies building boards and systems on chip for inference workloads these days. These have a large impact on memory bandwidth.
  • Inline operations on stored data structures, such as insertion, deletion, and update, which save main compute pipeline overhead and data movement.
  • Data movement within memory, or updates to large address ranges, such as MEMSET/MEMCPY/MEMMOVE, can be embedded as well.
  • Inline encryption and decryption of data bound for network interfaces, or specialized security for memory and storage regions holding multi-tenant client data, which requires computing hashes and running verification checks using private keys.
  • The list could go on, and adding in-storage operations would make it very exhaustive. Combined with in-memory compute, these can relieve the main multithreaded pipelines of a lot of data movement, state upkeep, and unnecessary memory bandwidth consumption, with some tradeoffs in the cost of platform, silicon, and software implementations. These in-memory/storage compute helper functions not only add value to the performance and security of inference and training pipelines but can also be a huge cost reducer.
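
As an illustration of the inline format conversion item in the list above, the following sketch models in software what a near-memory unit might do: data is stored at fp16 and widened to fp32 only when it is read. It uses Python's standard `struct` module; the function names `store_fp16` and `load_fp32` are hypothetical, and a real offload engine would do this in hardware on the memory side rather than in host code.

```python
import struct


def store_fp16(values):
    """Pack floats as little-endian IEEE 754 half precision (2 bytes each)."""
    return struct.pack(f"<{len(values)}e", *values)


def load_fp32(buf):
    """Widen a half-precision buffer to single precision (4 bytes each)."""
    n = len(buf) // 2
    halves = struct.unpack(f"<{n}e", buf)
    return struct.pack(f"<{n}f", *halves)


# Values chosen to be exactly representable in fp16.
values = [1.0, 0.5, -2.25, 1024.0]
stored = store_fp16(values)    # what resides in memory
widened = load_fp32(stored)    # what the compute pipeline consumes

print(len(stored), "bytes stored,", len(widened), "bytes delivered")
print(struct.unpack("<4f", widened))
```

The stored footprint is half of the fp32 footprint. If the conversion happens next to memory, the main pipeline neither spends instructions on it nor holds both copies.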

One can also blur the so-called CPU and accelerator boundaries with the optimizations listed above and thus reduce synchronization overheads. These in-memory and storage compute optimizations need to be tightly integrated with existing or new libraries in the given AI inference or training stack (e.g., communication collectives, DNN and linear algebra libraries, and KV cache frameworks such as SGLang, LMCache, or NIXL). The main tradeoff lies in making in-memory/storage compute functions flexible enough to serve not only primary use cases but also secondary and tertiary ones. For instance, in-memory compute integrated into the HBM-x hierarchy might have, as its primary use case, format conversion of data that local threads or accelerators retrieve from or store into memory, while as a secondary use case it can benefit DMA-driven updates to ring buffers and their metadata for network or storage IO traffic.

Reliability requires calculations such as checksums, CRCs, and ECC codes, which consume significant compute and, when errors occur, affect memory access bandwidth and latency. These functions can be readily integrated as in-memory/storage compute offload functions, alongside the performance-enhancing functions listed earlier. This may include encoding and decoding Reed-Solomon or BCH codes over a Galois field to detect and correct errors in the accessed locations or parts. This can add significant value to the observed IO and network access bandwidth to these memories, as well as to storage capabilities. Autonomous error handling requires RDMA retries and management of the redundancy provisioned in memory capacity, in addition to link and lane management, to avoid a severe shortfall in the functional performance of these systems. At the system/server level, there must be a mechanism to initiate graceful failover after errors are detected, which requires some form of checkpointing; in-memory/storage function offloading can help significantly here, enabling failover with minimal loss. There is a lot of inspiration to be drawn from techniques used in the large, CPU-based, fault-tolerant NUMA systems of years past.

Another related source of a large number of memory accesses is DMA traffic from distributed storage modules, which can be characterized as IO traffic (arriving over, e.g., PCIe v6/7/8 or NVMe at an individual socket) that loads working sharded snapshots from fast SSD or disk-based storage into and out of DDR-based memories. There are many finer details of these memory interactions, such as using near atomics and similar primitives to keep pages resident in memory (e.g., pinned memory) so they are not thrashed, which are out of scope for this post. It suffices to say that managing IO bandwidth and managing storage objects across near and far distributed storage has a large impact on the actionable content brought into memory by DMA over IO. These IOs also add to the demand on memory capacity and access bandwidth, which near-memory and storage offload functions can alleviate.
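
Returning to the integrity checks discussed above, the simplest of them, a CRC, can be sketched with Python's standard `zlib.crc32`. The helper names `protect` and `verify` and the sample payload are hypothetical. The sketch covers detection only; correcting errors needs an ECC such as Reed-Solomon or BCH, which is not shown here.

```python
import zlib


def protect(payload: bytes) -> bytes:
    """Append a 4-byte CRC-32 of the payload, as a write path might."""
    return payload + zlib.crc32(payload).to_bytes(4, "little")


def verify(block: bytes) -> bool:
    """Recompute the CRC on read and compare it with the stored one."""
    payload, stored = block[:-4], int.from_bytes(block[-4:], "little")
    return zlib.crc32(payload) == stored


block = protect(b"kv-cache page 17")

# Simulate a single-bit upset in the stored payload.
corrupted = bytearray(block)
corrupted[3] ^= 0x01

print(verify(block))             # True: block is intact
print(verify(bytes(corrupted)))  # False: the bit flip is detected
```

When the check runs next to memory or storage, the main pipeline only hears about blocks that fail, and the recomputation no longer costs it cycles or bandwidth on every access.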

Another big use case is database management. Current-generation AI workloads built on transformer models use vector databases (e.g., FAISS, Pinecone) to store and retrieve embeddings; these are primarily key-value stores, like those used in KV caches, and need to reside primarily in attached memory. Similarity searches executed on these key-value stores need to be fast on unstructured data like embeddings, which can have wide rows. Key metrics like tokens processed per hour can suffer if these in-memory key-value stores are not managed efficiently. Offloading such tasks to in-memory compute functions can free the main compute pipelines from DB update management.
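
The similarity search described above can be sketched as a brute-force scan in plain Python. This is not how FAISS or Pinecone work internally (they use specialized indexes); the names `cosine` and `top_k`, the document keys, and the three-dimensional embeddings are all made up for illustration. The scan touches every stored vector, which is why it is a natural candidate for execution near memory.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query, store, k):
    """Return the keys of the k embeddings most similar to the query."""
    ranked = sorted(store.items(),
                    key=lambda item: cosine(query, item[1]),
                    reverse=True)
    return [key for key, _ in ranked[:k]]


# Toy key -> embedding store (hypothetical data).
store = {
    "doc_a": [1.0, 0.0, 0.0],
    "doc_b": [0.0, 1.0, 0.0],
    "doc_c": [0.9, 0.1, 0.0],
}
nearest = top_k([1.0, 0.0, 0.0], store, 2)
print(nearest)
```

Only the k winning keys need to travel back to the main pipeline, while the full set of embeddings stays where it is stored.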

To see why in-memory compute is needed and what offload functions can buy, consider working through a hypothetical example with a popular accelerator on the market, such as the GB200. Note that current memory provisioning is mostly hierarchical. Besides internal caches (L1, L2, shared, constant, scratchpad…) and registers, external memories use DRAM/DDR-x protocols, while HBM-x uses wide IO links for high-bandwidth access to populate these internal memories. DDR-x (x = 5/6) memories are used mainly to provision high capacity, and HBM-x (x = 3e/4) is used to provide fast access but is capacity limited (e.g., a 16-high HBM4 stack delivers 64 GB with 2048 pins per stack across 16 layers of DRAM). For a given processing element such as an accelerator, it is necessary to know the capacity and bandwidth tradeoffs, together with an estimate of the bandwidth feasible from external memories in the worst case of no reuse in its internal memories, in order to correlate them with compute requirements or to decide on offload functions.

For instance, GB200 claims 32 PFLOPS for the FP8 data type. If we provision it with 8 stacks of the upcoming HBM4, the provisioning falls severely short if all caches and registers are ignored and all required bandwidth is sourced from the HBM4 stacks, even assuming capacity is sufficient. Each HBM4 stack delivers 2 TB/s, so 8 stacks connected to the accelerator deliver 16 TB/s. If we assume an arithmetic intensity of 8 FLOPs per byte, delivering 32 PFLOPS would need 4 PB/s (32/8) of bandwidth. State-of-the-art HBM4 in 8 stacks cannot source this (4 PB/s required versus 16 TB/s available). For this to work, most of the bandwidth must come from internal registers and SRAM (L1, L2, shared…) within the GPU, that is, from reuse of data already on chip. If the GPU's internal memories are not considered, we would need roughly 250 such 8-stack groups to deliver 4 PB/s (256 if 1 PB is counted as 1024 TB). This hypothetical makes it clear that, without data reuse in internal memories, even a very complex processing element surrounded by a large amount of high-bandwidth memory (here, 8 HBM stacks) would not deliver its compute performance (32 PFLOPS on FP8 at 8 FLOPs/byte).
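
The arithmetic of this toy example can be reproduced in a few lines. All figures are the ones quoted in the text (32 PFLOPS, 8 FLOPs per byte, 2 TB/s per HBM4 stack, 8 stacks); the variable names are illustrative. Decimal units are used throughout, which gives a ratio of 250; the figure of 256 corresponds to counting 1 PB as 1024 TB.

```python
PEAK_FLOPS = 32 * 10**15       # 32 PFLOPS (FP8), as quoted in the text
FLOPS_PER_BYTE = 8             # assumed arithmetic intensity
HBM_STACK_BW = 2 * 10**12      # 2 TB/s per HBM4 stack
STACKS = 8

# Bandwidth needed if every operand byte came from external memory.
required_bw = PEAK_FLOPS // FLOPS_PER_BYTE   # bytes per second
available_bw = HBM_STACK_BW * STACKS         # bytes per second

# How many times short the 8 provisioned stacks fall, which is also the
# average on-chip reuse each HBM byte needs to keep the pipelines fed.
shortfall = required_bw / available_bw

print(f"required : {required_bw / 10**15:.0f} PB/s")
print(f"available: {available_bw / 10**12:.0f} TB/s")
print(f"shortfall: {shortfall:.0f}x")
```

Read the other way around, the shortfall is the budget that caches, registers, and any near-memory offload functions must cover between them.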

Finally, one way to identify candidate compute and offload functions is to profile the functions performed by the internal pipelines (threads) and break them down by ISA sequence length and loop count. Then consider replacing some of the profiled sequences with offload functions in near-memory compute pipelines, and re-evaluate the mismatch between required and delivered bandwidth in this toy example (4 PB/s required versus 16 TB/s delivered by 8 stacks, a gap of roughly 250:1 when internal memories are ignored). This has two benefits. First, it might reduce the cost of memory (in this example, hundreds of 8-stack groups were needed versus the single group of 8 stacks hypothetically provisioned). Second, it takes cycles away from the main compute pipeline for the functions offloaded to memory compute, thus improving arithmetic intensity and the overlap of compute and communication. The arguments made here for memory offload apply similarly to storage offload, which would cut the DMA/RDMA cost to fast and slow connected memories.

Lastly, there are solutions on the market that address some of the points highlighted in this post, but they are very intrusive and costly implementations, because they attempt to integrate the most popular primitives (e.g., MatMul) very tightly with the SRAM design, which is in some sense neither future-proof nor easily scalable. In-memory/storage compute functions need to have flexibility, scalability, and fungibility built in from the start.