AWS re:Invent 2024 - Conquer AI performance, cost, and scale with AWS AI chips (CMP209)

Welcome, everybody, to "Conquer AI performance, cost, and scale with AWS AI chips." My name is Gadi Hutt, and I run product and customer engineering at Annapurna Labs. With me are three amazing speakers: Yang from the Amazon Rufus team, Joe from poolside, and Matt from Google DeepMind. We are here to cover the latest announcements on Trainium2 and to give you updates on training and inference in general. We have a lot to talk about, so let's get started.

I guess we are all here because we have some interest in AI, and in generative AI specifically: the technology and how it enables innovation. As Matt, our CEO, says, the job here at AWS is to give customers access to this amazing technology as securely, performantly, and scalably as possible. For us, the team at Annapurna Labs that builds the chips, the boards, and the software around them, that means we need to keep innovating on your behalf to deliver the scale and the performance you all expect.

We started this journey back in 2016: literally a few engineers in a room trying to figure out how to build ML systems that are high performance and cost effective. That has been our leading tenet ever since: providing performance at a reasonable cost. We delivered our first instances, Inf1, with 16 first-generation Inferentia chips, back in 2019, and we've been delivering ever since. And of course, the latest announcement this morning: Trainium2 is generally available with the Trn2 instances, and the Trn2 UltraServer is in preview.

Before we dive into the technology, I want to explain why we keep building systems that are larger and more performant. Most of you have probably seen this graph before. It shows that over the last decade or so, models have been growing at an exponential rate. The question we need to ask ourselves is why this is happening. Well, the answer is pretty straightforward, and the research behind it
is solid: it shows clear evidence that as model size grows, the accuracy and performance of the model improve. These graphs are from a famous 2020 paper, "Scaling Laws for Neural Language Models," which shows that performance improves predictably as long as compute, training data, and model size grow at a similar rate.

Let's zoom in on the performance aspect, because that is what impacts the infrastructure. What you see here is very interesting. This is a log-log graph, which means that the straight line you see, which you might intuitively read as an additive pattern, is actually a multiplicative one. If a data scientist wants to improve their baseline model by 50%, they have to invest 10,000x more compute to reach that result. 10,000x. What that means for us when we build these systems is that we of course need to keep improving performance, but we also have to keep power at bay. If we don't, we won't be able to scale to your needs.

So what does this mean for future models? The trends we see from customers are very clear: model sizes will continue to grow, so we need to keep up. This is where Trainium2 comes to the rescue. Trainium2 is the most complex chip we have designed to date. It packs 1.3 petaflops of dense compute. To give you a reference, a single Trainium2 chip has 30% more compute than the entire server of the first instance we launched only five years ago. We also added innovations here. For example, we are the only solution in the industry that provides 4x sparsity, which means that if your model is sparse, you can get up to a 4x boost in available compute.

Let's look at how this all comes together inside the server. This is the Trainium2 server. It packs a whopping 20.8 petaflops of compute, 46 terabytes per second of HBM bandwidth, and 1.5 terabytes of HBM memory. Compared to the other options on AWS, simply put, Trn2 has 30% more compute than the latest GPU instance, the P5en. Compute is very important for training and inference, but HBM memory has to keep up too, because it has a dramatic impact on model performance; on HBM we also have 30% more than is available on the P5en. And of course we are building these into the largest UltraClusters we've built to date.

Let's look at performance. These results were verified by Artificial Analysis, an independent benchmarking site. On the y-axis you see latency, or time to first token; on the x-axis you see throughput, or output speed. Trainium2 delivers more than 3x the throughput of the next available solution from other cloud providers. Another great example comes from Anthropic. What you see here are results, again verified by Artificial Analysis, for the Claude 3.5 Haiku model. Compared to Trn1, Trn2 is 60% more performant.
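A straight line on a log-log plot is a power law, so each fixed improvement costs a fixed multiple of compute rather than a fixed amount. A minimal sketch of that arithmetic, assuming loss scales as C^(-alpha) with alpha = 0.05, roughly the compute exponent reported in the 2020 scaling-laws paper (the exponent is an illustrative assumption, not a number from this talk):

```python
# Sketch: loss-vs-compute power law, L(C) proportional to C ** (-ALPHA).
# ALPHA = 0.05 is an assumed, illustrative exponent (roughly the Kaplan et al. 2020 value).
ALPHA = 0.05


def loss_reduction_factor(compute_multiplier: float, alpha: float = ALPHA) -> float:
    """How many times lower the loss gets when compute is multiplied by compute_multiplier."""
    return compute_multiplier ** alpha


def compute_multiplier_for(loss_reduction: float, alpha: float = ALPHA) -> float:
    """Inverse: the compute multiplier needed to lower the loss by the given factor."""
    return loss_reduction ** (1.0 / alpha)


if __name__ == "__main__":
    # 10,000x compute buys only about a 1.58x lower loss under this exponent,
    # in line with the talk's "roughly 50% better for 10,000x more compute".
    print(f"{loss_reduction_factor(10_000):.2f}x lower loss from 10,000x compute")
    # Halving the loss would need 2 ** 20, i.e. over a million times more compute.
    print(f"{compute_multiplier_for(2.0):,.0f}x compute to halve the loss")
```

Because the relationship is multiplicative, the same exponent that makes 10,000x compute worth about a 1.58x improvement makes a 2x improvement cost 2^20 (about 1,048,576) times the compute.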

Compared to the other cloud provider that is deploying Haiku 3.5, the result is 41% faster. Truly amazing performance. With performance in mind, I want to hand it over to Yang from the Rufus team to tell you how Amazon Rufus scaled with Inferentia and Trainium across the world.

Let me know if you can hear me. Awesome. I'm Yang, a software engineer on the Rufus team, and I work on inference optimization. Today I'm excited to share our journey: how we built, optimized, and overcame challenges on our way to scaling Rufus to handle Amazon-scale traffic, using AWS AI chips all the way. So let's dive in.

Rufus is a system that can answer customers' shopping questions. It is powered by a large language model trained from scratch by the Rufus team, using Amazon's product information as well as information from across the web. In this example (I think it's not showing), a customer asks how to choose a board game, and within a second the model generates a response with the answer the customer needs. Within the response we also embed links and images that help the customer continue exploring with Rufus.

Since Rufus launched this July, we have successfully supported multiple major peak events. For example, during Prime Day in July, millions of customers shopped with Rufus, served by a cluster of more than 80,000 Inferentia and Trainium chips. Later this year we supported even larger traffic during Prime Big Deal Days: during peak hours we generated 3 million tokens per second, and over the two-day event we generated 320 billion tokens in total.

When we started to productionize Rufus, in order to serve tens of
thousands of requests per second using thousands of hosts, we found the Trainium- and Inferentia-powered Trn1 and Inf2 instances to be our best option, and here is why. First, compared with GPU hosts of similar performance, the Trn1 and Inf2 hosts performed better thanks to high-speed chip-to-chip communication, and they also have larger accelerator memory, which will let us host even larger models in the future. Further, the Trn1 and Inf2 hosts are very cost efficient, and they are highly available around the world, so we can scale them and use them to serve customer traffic worldwide.

Next, let's take a look at our system architecture. We use Amazon ECS to deploy our service. Because the service in each region can launch both Trn1 and Inf2 instances, we can scale up whenever there is a traffic burst, and our resiliency is also very high because we use cross-region hosting. Within the ECS service in each region, we build and deploy containers that include Triton Inference Server powered by vLLM and the AWS Neuron SDK. This setup positions us to leverage the optimizations coming from both vLLM and the Neuron SDK, while also benefiting from the abundant inference-server functionality available in Triton.

Based on profiling of our first end-to-end results, we found that performance was not optimal, so to achieve a smoother customer experience we defined a new latency goal, which I would say is quite aggressive given our model size. Before I introduce the optimizations that helped us bridge the gap, I would like to give a quick overview of how LLM inference works. For example, a customer asks what to buy for camping, and in less than one second the model returns what you need to know when
camping, with insights and recommendations included. What is happening inside the model during this process? At a high level, LLM inference can be separated into two stages: context encoding and decoding. Context encoding requires the model to build context over thousands of input tokens and process them all at the same time, before the first token is even generated. Decoding predicts tokens sequentially: the output token from the previous step is appended to the input of the current step to generate the next token, and this repeats until a stopping criterion is met.

Context encoding processes many tokens at the same time and reuses the model weights across them, so it is usually compute bound, meaning it goes faster with more compute FLOPS available. In decoding, the KV cache saves a lot of intermediate results and removes a lot of compute, but we still need to move the model weights for every token we generate one by one, so this stage usually goes faster with larger memory bandwidth.

Now let's look at a few critical optimizations that helped Rufus speed up. The first is streaming. Instead of returning tokens at the very end, once they are all generated, we return each one right after it is generated. Compared with returning everything at once at the end, this reduced the first-response latency from 3 to 15 seconds, depending on response length, down to 1 second. In addition, the streamed response is smartly aggregated with links and image previews inserted into it, which lets customers continue to shop and engage with Rufus.
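Streaming changes when the user first sees output, not how long the full answer takes. A minimal sketch under a simple additive latency model; `fake_decode`, the 300 ms prefill time, and the 20 ms per-token time are hypothetical placeholders, not Rufus measurements or a vLLM or Neuron API:

```python
from typing import Iterator, List


def fake_decode(answer: str) -> Iterator[str]:
    """Hypothetical stand-in for a model's decode loop: emits one 'token' per word."""
    for word in answer.split():
        yield word


def respond_blocking(answer: str) -> List[str]:
    # Non-streaming: the caller gets nothing until every token has been generated.
    return list(fake_decode(answer))


def respond_streaming(answer: str) -> Iterator[str]:
    # Streaming: forward each token to the caller as soon as it is produced.
    yield from fake_decode(answer)


def first_output_latency_ms(prefill_ms: float, per_token_ms: float,
                            n_tokens: int, streaming: bool) -> float:
    """Time until the user sees anything: prefill plus the decode steps before the first send."""
    steps_before_first_send = 1 if streaming else n_tokens
    return prefill_ms + per_token_ms * steps_before_first_send


if __name__ == "__main__":
    stream = respond_streaming("pack a tent and a headlamp")
    print(next(stream))  # the first token is available before the rest are generated
    for n in (150, 500):  # a short and a long answer
        blocking = first_output_latency_ms(300, 20, n, streaming=False)
        streamed = first_output_latency_ms(300, 20, n, streaming=True)
        print(f"{n} tokens: blocking {blocking:.0f} ms, streaming {streamed:.0f} ms")
```

Under these placeholder numbers the blocking response shows nothing for 3.3 s to 10.3 s depending on length, while the streamed one shows its first token after 320 ms regardless of how long the full answer is.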
Another optimization we enabled is called multi-prompt. We previously used one prompt to ask the model to generate not only the response, highlighted here, but also the related questions, highlighted in blue. With multi-prompt, we break the prompt into an initial prompt and a set of follow-up prompts. The initial prompt has fewer tokens, so context encoding becomes faster. And by inserting the follow-up prompt only after the response to the initial prompt is generated, we also improved model quality, because of better prompt following.

Recall that LLM inference, especially the decoding phase, is bound by memory bandwidth. One way to mitigate this is quantization, which compresses the model weights. For example, a model with 70 billion parameters usually takes around 140 gigabytes to store its weights if they are represented as 16-bit floating point numbers. If we quantize them into 8-bit integers, the same weights take only around half that size, and with less data, the memory movement we mentioned earlier is also faster. For Rufus, we reduced both context encoding and decoding latency by quantizing our model weights from FP16 to INT8.

Besides the optimizations I just mentioned, Rufus inference also continues to benefit from the Neuron compiler-level improvements incorporated into each official release. With all of these optimizations in place we achieved our latency goal, and everything I mentioned is already available as part of the Neuron SDK today.
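The weight-memory arithmetic above, together with a minimal symmetric INT8 quantizer, can be sketched as follows. This is a generic illustration, not the Neuron SDK's quantization path: per-tensor scaling and the 1 TB/s bandwidth figure are assumptions, and real deployments shard weights across many chips. Because decoding reads every weight once per generated token, weight bytes divided by memory bandwidth gives a rough lower bound on per-token decode time:

```python
from typing import List, Tuple


def weight_gb(n_params: float, bits: int) -> float:
    """Storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits / 8 / 1e9


def decode_floor_ms(n_params: float, bits: int, bandwidth_gb_s: float) -> float:
    """Rough lower bound on per-token decode time if all weights are read once per token."""
    return weight_gb(n_params, bits) / bandwidth_gb_s * 1000


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization: w is approximated by scale * q, q in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    return [scale * v for v in q]


if __name__ == "__main__":
    n = 70e9  # a 70-billion-parameter model, as in the talk
    print(f"FP16: {weight_gb(n, 16):.0f} GB, INT8: {weight_gb(n, 8):.0f} GB")
    bw = 1000.0  # hypothetical 1 TB/s of memory bandwidth on a single device
    print(f"decode floor: FP16 {decode_floor_ms(n, 16, bw):.0f} ms/token, "
          f"INT8 {decode_floor_ms(n, 8, bw):.0f} ms/token")
    q, s = quantize_int8([0.5, -1.27, 0.0, 1.0])
    print(q, dequantize_int8(q, s))
```

Halving the bytes per weight halves the bandwidth-bound floor, which is why INT8 helps decoding even though it does not reduce the number of operations.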
Next, I would like to share the lessons we learned when we scaled up our service. There were a couple of challenges. The first: in general web applications today, a single host can usually handle hundreds or even thousands of requests per second, but for large language model applications, because of the huge compute and memory requirements, a single host may only handle tens or even single-digit requests per second. The second constraint is end-to-end latency. For Rufus, the response length depends on the customer's question. If the customer asks "Is this product waterproof?", the response is short and takes around 3 seconds. But if the customer asks a broader question, such as what to buy for camping, the response includes not only insights but also additional recommendations, and it takes longer to generate in full.

With variable latency and low per-host concurrency, distributing requests evenly with the traditional round-robin load-balancing mechanism does not work well: some hosts end up working very hard, crunching through long requests, while others run with very little load. With this setup, overall system availability started to drop earlier than we expected, before all the hosts in the fleet were fully utilized. To overcome this challenge, we used the least outstanding requests algorithm available in the AWS Application Load Balancer. This algorithm distributes requests based on how busy each host is, and for Rufus, compared with round-robin load balancing, it improved our overall fleet-wide throughput by 20%.
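Why round-robin struggles with variable-length requests shows up in a toy simulation. This is a minimal sketch, not the Application Load Balancer's implementation: each host serves one request at a time from a FIFO queue, "outstanding" means assigned but not yet finished, ties go to the lowest-numbered host, and the request mix (every third request is long) is made up for illustration:

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Host:
    free_at: int = 0                                   # time the host drains its queue
    finish_times: List[int] = field(default_factory=list)

    def outstanding(self, now: int) -> int:
        """Requests assigned to this host that have not finished by time `now`."""
        return sum(1 for f in self.finish_times if f > now)

    def assign(self, arrival: int, duration: int) -> int:
        """Queue a request (FIFO, one at a time) and return its end-to-end latency."""
        start = max(arrival, self.free_at)
        self.free_at = start + duration
        self.finish_times.append(self.free_at)
        return self.free_at - arrival


def simulate(requests: List[Tuple[int, int]], num_hosts: int, policy: str) -> List[int]:
    hosts = [Host() for _ in range(num_hosts)]
    latencies = []
    for i, (arrival, duration) in enumerate(requests):
        if policy == "round_robin":
            idx = i % num_hosts
        elif policy == "least_outstanding":
            # min() returns the first minimum, so ties go to the lowest index.
            idx = min(range(num_hosts), key=lambda h: hosts[h].outstanding(arrival))
        else:
            raise ValueError(f"unknown policy: {policy}")
        latencies.append(hosts[idx].assign(arrival, duration))
    return latencies


if __name__ == "__main__":
    # One request per time unit; every third request is long (9 units), the rest short (1 unit).
    reqs = [(t, 9 if t % 3 == 0 else 1) for t in range(9)]
    for policy in ("round_robin", "least_outstanding"):
        lat = simulate(reqs, num_hosts=3, policy=policy)
        print(f"{policy}: total latency {sum(lat)}, worst {max(lat)}")
```

With round-robin, every long request lands on the same host and queues behind the previous one (total latency 51 units, worst case 21); least outstanding requests spreads them out (total 39, worst case 9).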
Another optimization that helped Rufus further improve throughput with the same cluster size is continuous batching. Suppose we want to process five requests, as in this example. Without batching, we have to process them one by one. With static batching, the server waits for a certain period of time to accumulate requests into a batch and process them in parallel, with the maximum batch size configured in advance. For example, with a maximum batch size of three, requests one, two, and three have to be processed and finished before we can start requests four and five, and that waiting time wastes compute resources. To optimize further, we use continuous batching, where the server batches sequences at the level of token generation. Each request can return its response as soon as it is completed, and as soon as request two finishes, for example, we can insert request four and start processing it. With continuous batching, the Rufus team improved single-host throughput by 4x, and this optimization is already available today through the vLLM and Neuron SDK integration.

Looking ahead, I'm super excited about what's coming next. To name a few: we will use new compute functionality available on Trn2 hosts, such as low-bit quantization compute and sparse compute, to power our next-generation model. We will also use NKI kernels to develop advanced algorithms that help us process hundreds, thousands, or even millions of tokens in a shorter time. And last but not least, we will continue contributing to the optimizations available in vLLM and the Neuron SDK. That's all. Thank you so much, and back to Gadi.

Thank you very much, Yang. We look forward to continuing to work with you and the team to bring performance to all of Amazon's customers. Super exciting work.
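The difference between static and continuous batching shows up even in a tiny step-count model. A minimal sketch, assuming one decode step emits one token for every in-flight request, every request needs at least one token, and prefill is ignored; the request lengths are made up, and this is not vLLM's scheduler:

```python
from collections import deque
from typing import List


def static_batching_steps(lengths: List[int], max_batch: int) -> int:
    """Decode steps when each fixed batch must fully finish before the next starts."""
    steps = 0
    for i in range(0, len(lengths), max_batch):
        steps += max(lengths[i:i + max_batch])  # the batch waits for its longest request
    return steps


def continuous_batching_steps(lengths: List[int], max_batch: int) -> int:
    """Decode steps when finished requests are replaced at token granularity."""
    pending = deque(lengths)
    active: List[int] = []  # tokens still to generate for each in-flight request
    steps = 0
    while pending or active:
        while pending and len(active) < max_batch:  # refill any free slots
            active.append(pending.popleft())
        active = [r - 1 for r in active if r > 1]    # one token each; drop finished requests
        steps += 1
    return steps


if __name__ == "__main__":
    lengths = [3, 8, 2, 6, 4]  # output tokens for five requests; batch size 3 as in the talk
    print("static:", static_batching_steps(lengths, 3))
    print("continuous:", continuous_batching_steps(lengths, 3))
```

With these numbers static batching needs 14 decode steps and continuous batching 8, because requests four and five start as soon as slots free up instead of waiting for the longest request in the first batch.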
So let's go back to our compute needs. Before Yang's talk, I mentioned that Trn2 has 30% more compute and more memory than the latest GPU machine. This is certainly very exciting, but customers also come to us and say, "I want to train a one-trillion, two-trillion, even five-trillion-parameter model," and we realized we might be leaving some performance on the table. The reason is that even the largest instance we can build will not have enough compute and memory to efficiently train and deploy these monster-size models. This is where the UltraServer comes into play.

Please welcome the UltraServer. An UltraServer is essentially four Trn2 instances combined with high-bandwidth NeuronLink to deliver a whopping 80-plus petaflops of dense compute, over 300 petaflops of sparse compute, 185 terabytes per second of HBM bandwidth, and 6 terabytes of HBM. Essentially, you get 4x the performance in a single high-performance compute node, the highest available on AWS. We are already using these UltraServers to build Project Rainier together with Anthropic. Project Rainier will have hundreds of thousands of Trainium2 chips working in a single cluster, providing more than five exaflops of compute, which is 5x more exaflops than Anthropic has access to today. This amazing amount of compute will give Anthropic's data scientists and engineers the ability to innovate at an unprecedented scale, with the largest models they have built to date. We are very excited to collaborate with Anthropic on this, and it's coming very soon.

But unlocking all of this performance requires an SDK that simply works and delivers maximum performance, with code portability and usability in mind. Let's talk about the Neuron SDK. The Neuron SDK supports a wide array of frameworks, libraries, and services. At the bottom of the stack is OpenXLA, a unified open source compiler project we co-founded together with Google and Meta to ensure that customers can
run their Trainium code with maximum code portability. Trainium is also fully supported by popular open source libraries like Hugging Face, PyTorch Lightning, and Ray, as well as by observability platforms like Datadog and high-performance inference projects like vLLM, which you heard about from the Rufus team, and many more. The Neuron SDK is built from a few major components: the compiler, the runtime stack, and the framework integrations, plus usability and debuggability tools. You probably heard from Anthropic at yesterday's keynote that this is the only accelerator that gives them cycle-accurate data when they profile; that is all enabled by the Neuron Profiler. The architecture leverages open standards where possible and supports performant distributed training and inference.

We'll talk about JAX at length soon, but I also want to touch on some innovations we have implemented for PyTorch since last re:Invent, specifically the NeuronX Distributed library, or NxD for short, for training and inference. Let's start with training. NxD Training is an open source PyTorch library that is fully optimized for large-scale model training. It combines ease of use with advanced capabilities for large-scale training, with core models for engineering and data science teams and wide support for training workflows: pre-training, fine-tuning, optimization techniques like LoRA, and model alignment. It's all open source and built on top of PyTorch.

Let's take a quick look at results from NxD Training. We took PixArt, an open source image generation model, and asked it to draw "a blue jay standing on a large basket of rainbow macarons." Don't ask me why that's the prompt; we thought it would be colorful and nice. Oops, I had a nice animation here that went away, but anyway,
what we are showing here is that the image starts as a blur and improves as training progresses, with a very nice result at the bottom right once the model is fully fine-tuned.

Let's talk about NxD Inference. Using building blocks similar to the ones for training, we take a similar PyTorch approach to inference with large models, whether LLMs, multimodal models, or computer vision models. The library provides performance features like continuous batching and speculative decoding, and it has a model hub of production-ready models like Llama 3.1, Llama 3.2, DBRX, and Mixtral, plus multimodal models like LLaVA, all optimized for Trainium and Inferentia and ready to go. Like NxD Training, it has a modular architecture that lets you bring your own model into NxD Inference and reuse all the features quite easily. And of course it has extensive support for serving libraries: we mentioned vLLM a couple of times, but also Hugging Face TGI, Ray Serve, LitServe from PyTorch Lightning, and NVIDIA Triton Inference Server. We worked really hard on NxD together with the Meta team to make sure it's user friendly and straightforward to use, and I encourage you all to try it soon.

Usability and ease of use are super important to us, so we asked ourselves how we could improve usability further for our customers. This is where I'm very happy to announce Neuron Expert in Amazon Q Developer. Neuron Expert is a virtual solutions architect at your fingertips: you can ask general questions about Neuron and get quick references. Let's see how it works. Here we ask how to shard models with Neuron, and in seconds you get a detailed answer with reference links, so you can deepen your knowledge, ask follow-up questions, and engage with Q quite easily. This has been really impactful in reducing our support load, but
it also lets you move quickly, and it's available today in Q Developer, either in your console or in your favorite code editor.

But we wanted to do a bit more. We thought this is cute and will help you a lot, but we could go further. As a reminder, last year on this stage I announced NKI, the Neuron Kernel Interface, which allows developers to build performant, customized compute kernels directly on top of Trainium. NKI has two APIs. One is NKI ISA, low-level access to the chip's instruction set: you can use the chip as a bare-metal chip and have full control over all of its engines. The other is the NKI language, which is aligned with the syntax and tile-level semantics available today in OpenAI Triton, as well as NumPy. Developers are using NKI very successfully, and you'll hear about it soon from poolside. To get more folks familiar with NKI, you can now generate NKI code with Q Developer. In this example, a simple prompt asks the model to generate a matmul kernel. The code is generated in seconds, including documentation and an explanation of what was built, and from there you can copy and paste it into your favorite code editor, VS Code or another, and start your project. Here's a fun fact: the model behind Q Developer is Anthropic's Claude, and we just discussed that Claude runs on Trainium. So what you see here is actually Trainium generating code for Trainium, which I think is pretty cool.

With that, I want to invite Joe from poolside to talk about the work they have done on Trainium and Trainium2.

Can you hear me? Cool. Hi everyone, I'm Joe from poolside, and I just want to give you some context about poolside and
what it is that we do. poolside is a frontier AI company pursuing AGI by building AI models for software development. That's a very ambitious goal, and for us it demands obsession, which stretches through everything we do, including the performance of our inference stack. At the moment our inference stack is primarily based on PyTorch, admittedly mixed with some Triton and some CUDA for high performance. I'm telling you this because I see us as the typical non-trivial case: you have existing code that needs to be ported, and there are some things you need to change. I want to talk about our experience with that.

Two things have made porting perhaps slightly more difficult for us than it would be for you. First, we are performance obsessed. Like a lawyer, I will debate every cycle, and for us this means we can't accept any slowdown at all in the speed of our inference. Second, we regularly try out new ideas in our inference stack, so having to reimplement everything multiple times would be a non-starter. I do want to stress, though, that even though I've just listed these barriers, they haven't been a problem in practice when it came to porting things to Neuron. I want to give you some high-level tips and tricks we've learned along the way, to help you if you decide to make the transition.

Oh, sorry, I should have done that. The first tip I want to share is that you need to change your mindset. When we came to port our platform onto Neuron, we essentially took the PyTorch code we had, copied it over onto Neuron, and worked around the Triton and CUDA we had. This mostly worked straight away, which quite surprised us, but of course there were some things we needed to change and fix, and fundamentally those tended to
come down to places where we were still operating under a different mindset: we were still thinking about writing code for, say, CUDA, and writing things for Neuron is not quite the same. I'd liken this to learning a new language that is similar to one you already know. For the first little while you're trying to build understanding and stumbling around, but eventually you reach a point where you understand how the system works. For us that took roughly a week; it really didn't take much time at all once we sat down and got into it. Of course, as with any language, you learn something every day and the process continues, but it's worth shifting your mindset from what you think is fast to what is actually fast on Neuron.

The second tip is to use the tools you're given. Gadi just spoke about the Neuron SDK, and in particular I'm talking about the libraries it provides, NeuronX Distributed and torch-neuronx. When we came to port our code, as I mentioned, we just took our PyTorch and copied it straight over, and after we fixed some things it worked, but we weren't necessarily getting the performance we wanted. Completely coincidentally, we switched to the tracing offered by one of those libraries and got a 100% speedup across multiple devices. We didn't expect this; we were trying to solve something completely different, but it was obviously a performance win, and we took it. By taking advantage of the platform in whatever way you can, you tend to get more out of the Neuron hardware. In the language analogy, this is like using idioms and catchphrases you've learned in your new language: you aren't just translating directly, you're picking up things that will help
you contextually. Now, this next one in particular, exploiting the hardware, is about writing code that is actually dedicated to Neuron. When we started, as I said, we ported over with PyTorch, and the PyTorch tracing compiler is superb, by the way. It works most of the time, and we rarely run into hard limits that we can't beat. But as a developer, you know there are situations where you can out-optimize the compiler, because you can make assumptions and decisions that it simply cannot. For us, that meant using the NKI language that Gadi just mentioned. The important thing is that you can write code in a Pythonic flavor that takes direct advantage of the underlying hardware: you can specify operations that explicitly use certain engines or certain hardware features, and you can profile them perfectly with the Neuron Profiler. One of the really nice things about NKI is that it's essentially Python-esque, so you can pick it up very quickly if you know Triton or any other Pythonic language. You don't need low-level assembly experience or deep technical knowledge of things like linking; you can simply write something familiar and concise, and it's very easy to pick up.

Another nice thing about NKI, and about the Neuron hardware, is that it lets you express very succinctly ideas you otherwise can't. For example, we've written some specialized kernels for token generation in the inference step. They are very succinct, maybe 50 lines or so, but you can get something like a 100x speedup in the inference step and a 2x speedup in the prefill step, when the prompt is processed at the beginning. This is money lying on the ground; you have to pick it up.
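The split between prefill and token generation, and the matrix-matrix versus vector-matrix distinction the speaker draws next, come down to arithmetic intensity: FLOPs performed per byte moved. A minimal sketch, assuming BF16 (2-byte) operands, a hypothetical 4,096-wide layer, and a 2,048-token prompt; it counts one read of each input and one write of the output and ignores caching and tiling:

```python
def arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for C[m, n] = A[m, k] @ B[k, n] (one read of A and B, one write of C)."""
    flops = 2 * m * k * n  # one multiply and one add per (i, j, l) term
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


# Ridge point from the Trn2 instance figures quoted earlier in the talk:
# 20.8 PFLOPS of dense compute over 46 TB/s of HBM bandwidth. This is rough, because the
# peak-FLOPS precision and the 2-byte operands assumed here need not match.
RIDGE_FLOPS_PER_BYTE = 20.8e15 / 46e12

if __name__ == "__main__":
    prefill = arithmetic_intensity(m=2048, k=4096, n=4096)  # whole prompt: matrix x matrix
    decode = arithmetic_intensity(m=1, k=4096, n=4096)      # one new token: vector x matrix
    print(f"ridge ~{RIDGE_FLOPS_PER_BYTE:.0f} FLOPs/byte")
    for name, ai in (("prefill", prefill), ("decode", decode)):
        bound = "compute" if ai > RIDGE_FLOPS_PER_BYTE else "memory bandwidth"
        print(f"{name}: {ai:.1f} FLOPs/byte -> {bound} bound")
```

Prefill lands at 1,024 FLOPs per byte, above the roughly 450 FLOPs-per-byte ridge, while a single-token vector-matrix product sits just under 1, so decode time is dominated by moving weights rather than by arithmetic.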
I just want to stress that last point again. Something we've noticed as we've been using this platform is that a great many kernels you'll see are based on this top case here: they consider matrix-matrix operations, which is great, because you can process your prompt very quickly. But we found that this is not always where you spend your time during inference. You actually spend more of your time on the second type of operation, where you take a vector and multiply it by a matrix, add it to a matrix, and so on. I won't go into the technical details of how we've done this, but we've spent a great deal of time optimizing our kernels for that second case. To give you an idea, we're now in the unique position where the attention mechanism in our inference is no longer a bottleneck, which is quite unusual, and that's simply because we spent so much time tweaking our kernels and adapting them directly to the Neuron hardware with NKI.

But of course we always like more hardware, and having worked rather closely with the Annapurna team, we got early access to Trainium2. I'll be honest with you: switching to Trainium2 was, quite frankly, a magical experience for us. We took the code we had, put it on a Trainium2 box, and got 50% more performance for free. I didn't change anything, and that's quite an unusual experience. There are other nice things about Trainium2 as well. First, Trainium2 has 128 NeuronCores in each instance, four times as many as a Trainium1 instance, so right off the bat you can serve four times as many copies of your model per instance, which for us was a big win. The second thing I really like is that Trainium2 supports something called virtual cores. I just said that Trainium2 has 128 physical NeuronCores, but you can group two of them at a time into one virtual core from the perspective of your software, and this gives you a couple of advantages. First, it means that
you can present more memory to your application. In particular, when you use the virtual core grouping, you get 24 GB of HBM per virtual core, which is more than the 16 GB of HBM per NeuronCore offered in Trainium1. The second thing, of course, is that you get even more on-device parallelism: you have two cores processing the same amount of work, which is an amazing use case. As AI applications get larger and there tends to be more and more going on, having the ability to churn through all of that input faster is a complete game changer. And even when you group those two cores as one virtual core, you can still run twice as many model instances per box: there are 64 virtual cores per box, which is twice as many as on Trainium1. For us this is an unbelievably useful feature, and it's something that we look forward to taking advantage of.

Now I just want to show you some performance numbers. I want to stress that these numbers are entirely provisional; we've still got a lot of optimization work left to do. On the left here you can see our numbers for Trainium1: at a 2K context length we get about 26 tokens a second, and that drops when you go to the 4K length. In the middle there you have Trainium2. This is just naively ported; we haven't done any optimizations at all for Trainium2 in that middle set, and you can see that you get that free, roughly 50% increase at the 2K length, and slightly less at the 4K context length. That's okay.

This last column at the end is the elephant in the room. I've just spent all of this time talking to you about how you can optimize by using two cores as one virtual core, but when we profiled our model, we discovered that we weren't actually using both cores fully in every situation that we could. Thankfully, NKI has a nice new feature coming called sharding, which enables you to explicitly say, hey, I've got
this input data in this NKI kernel, I want this core to process this part, and I want that core to process that part. Of course, this just lets you explicitly specify what parallelism you want and how your data should be processed, and as you can see, this leads to a substantial performance increase even over the regular Trainium2 code. Admittedly, it drops a little for the 4K context length, but as I said, this is primarily because we haven't finished optimizing yet, and we look forward to continuing to optimize this code for Trainium2. And with that, I'll hand back over to Gadi.

Thank you, Joe. Amazing work by the poolside team, and we are looking forward to continuing to partner with you and reaching new performance heights. This is just super exciting for me.

So we spoke a lot about PyTorch; let's change gears and talk about JAX. Today, with the launch of Trainium2, we also announced an open source partnership with Google. We are very proud to partner with Google around JAX to enable more use cases, and code portability with JAX allows customers to choose the right hardware for their solution. Of course, JAX's native integration with the OpenXLA project, which I mentioned before, makes it really easy, and I hope you'll see that very soon when Matt comes on stage. There are a few customers already using JAX. The Anthropic team, which we talked about quite a lot today, actually started using JAX over a year ago, and they have seen great results in porting their code from the platforms they were using onto Trainium very easily, because of the code portability that is built natively into JAX. With that, I'm very happy to invite Matt to the stage. He's actually the original author of JAX, which started as, I guess, a skunkworks project within Google and then grew into the amazing framework it is today. I'm very proud to welcome Matt to the stage.

Hey, thanks, Gadi. Yeah.
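Joe's point that inference time goes mostly into vector-matrix rather than matrix-matrix work can be made concrete with a back-of-the-envelope arithmetic-intensity estimate (FLOPs per byte of memory traffic). This is an illustration under stated assumptions, not poolside's analysis: it assumes a single 4096x4096 weight matrix in a 2-byte format such as bf16, counts 2·n·d_in·d_out FLOPs for the product, and assumes every operand crosses memory exactly once. The function name and the hidden size are hypothetical.

```python
def matmul_arithmetic_intensity(n_tokens: int, d_in: int, d_out: int,
                                bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for (n_tokens x d_in) @ (d_in x d_out).

    Assumptions (illustrative only): one multiply and one add per
    inner-product term, and the activations, the weights, and the result
    are each read or written from memory exactly once.
    """
    flops = 2 * n_tokens * d_in * d_out
    elems_moved = n_tokens * d_in + d_in * d_out + n_tokens * d_out
    return flops / (elems_moved * bytes_per_elem)


if __name__ == "__main__":
    d_model = 4096  # hypothetical hidden size
    # Prefill: the whole 2K-token prompt goes through the weights at once.
    prefill = matmul_arithmetic_intensity(2048, d_model, d_model)
    # Decode: one new token (a single vector) per generation step.
    decode = matmul_arithmetic_intensity(1, d_model, d_model)
    print(f"prefill (matrix-matrix): {prefill:.1f} FLOPs/byte")
    print(f"decode  (vector-matrix): {decode:.4f} FLOPs/byte")
```

Under these assumptions the 2K-token prefill works out to 1024 FLOPs per byte, while a decode step does slightly under 1 FLOP per byte, so decode kernels are typically limited by memory bandwidth rather than compute. That is one way to see why kernels tuned only for the matrix-matrix case can leave decode performance on the table.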
All right, sound okay? All right. So I'm Matt, I work at Google DeepMind, and I'm here to tell you just a little bit about JAX. This is JAX in one slide. JAX is a Python library for numerical computing and especially large-scale machine learning. You can see on the left an example of JAX code implementing a very trivial fully connected neural network, and you can see that writing JAX code looks just like writing NumPy code. Behind the scenes, when you call JAX NumPy functions, JAX uses the OpenXLA stack and the XLA compiler to run those functions on CPUs, GPUs, TPUs, and now also on Trainium. JAX was built from the ground up to use XLA, and being able to run NumPy functions on accelerator hardware this easily would have been a dream for me in grad school, so I'm glad that it exists now.

So that's the first half of JAX: just run NumPy on any accelerator hardware you want, which sounds pretty great already. The second half of JAX, where it really starts to shine, has to do with what we think of as function transformations: these colorful grad, jit, and vmap things I've highlighted here. Function transformations are things like grad for autodiff: if you want to train your network, you use autodiff, and grad is our function transformation for doing that. jit, for just-in-time compilation, is where you can say: I want to stage out this entire function to the compiler. Not just compile individual JAX operations, but stage out, end to end, the function that I wrote for my neural network, whether that's the prediction function, the loss function, or the optimizer update step. By staging out to the compiler, not only can the compiler do optimizations like deciding fusions or layout optimizations, but this is also what gives JAX the power to scale from one accelerator chip to many thousands. So people have been able to do a
lot of different things with JAX, and we've heard about some of them. As you may know, all of Google DeepMind's large language model and large multimodal model work is done with JAX. Other companies like Anthropic and Apple Intelligence, and also xAI on GPU, use JAX to power a lot of their large-scale machine learning work. There's also work on image and video generation, like Imagen from Google and Midjourney, and work from Google DeepMind on mathematical reasoning. Beyond standard machine learning, JAX also helps power a lot of scientific computing: for example, AlphaFold 3 for protein structure prediction, which is Nobel Prize-winning work. It's also used in robotics, for things like simulation on accelerators, and for large-scale climate modeling. And finally, JAX is used all over academic research, from synthetic biology to astronomy to cutting-edge AI research.

So how exactly does JAX help all these people and power all these things? It really comes down to this: JAX gives you AGI. By that I mean a simple, composable, functional API, paired with the fact that it's geared towards compilation and scaling up from the beginning, and inherent portability across all of these platforms.

To try to show you some of those things and give you a concrete taste, I'm going to attempt a live demo. Let's see if we can pull this off. All right, hey, great. This is a Trainium instance; it's Trainium1 because I prepared this before this morning, but I'm excited to try this on Trainium2. I've started this notebook, so let's just start running NumPy operations. Here we've imported JAX NumPy, and we're making a big matrix, doing a matrix multiply with it, and printing out some results. That's great, and you might not even know that this is
running behind the scenes on a Trainium chip. But if you look, you can see that yes, the data does indeed live on a Trainium NeuronCore here. Using just these kinds of NumPy operations, we can start to write a neural network. Here's our toy neural network from our first slide, along with some parameter initialization code and some synthetic random data, and we can start to run it: let's run a prediction step, and let's evaluate our loss function. So that's just running JAX NumPy on Neuron, or on any other platform you want.

If you want to start to stage things out, to really start to get performance, you can use jax.jit. That's our decorator for just-in-time compilation: you apply it to a Python callable, and it gives you a new Python callable that is staged out end to end, a sort of whole-graph capture, if you like. So here we're running our prediction function and our loss function, and up to floating-point tolerances you can see we're computing the same numerical function, but now it's been staged out.

Where jit really starts to shine is when we do parallelization through automatic partitioning of computation, so let's take a look at that. We were actually just running computations on one NeuronCore, but that's not making good use of this machine, because we have 32 NeuronCores available. Moreover, you might want to scale out to hundreds or thousands of chips. So let's look at how JAX has one unified approach to expressing batch data parallelism as well as any other kind of parallelism you might like. Here's a simple batch data parallelism example: we're going to take our devices and say, let's make a mesh out of them. In this case we're going to start with just a simple one-dimensional mesh of our 32 devices, laid out in a line, and we're
going to say, let's name that single axis the batch axis. Then we're going to put our data, our inputs and our targets, onto that mesh. You can think of our data as two-dimensional, with the batch axis as the leading axis, so we're saying: take chunks of the inputs along the batch axis and give each chunk to one of these NeuronCores. This is just a simple visualization of our 2D array of data, and even though it appears to JAX as one array, we can see that it's backed by memory across all these different devices: we've given one chunk of that data to each of the NeuronCores. Now we can do the same thing we did before: we can compute the loss function and run predictions, and this is actually being parallelized for us behind the scenes. In this case the parallelization is really simple, just batch data parallelism, but it happens behind the scenes in the compiler, which of course has to recompile for every new sharding. And we can see in this simple batch-parallel case that our predictions are coming out laid out in a parallel manner as well. That gives you some indication that we are indeed keeping the data distributed and partitioning the computation to run in parallel.

Let's start mixing that in with more. Here we're going to use jax.grad for autodiff to do some training of our toy neural network. You apply jax.grad to the loss function, and that gives you a new function that computes the gradients with respect to the first argument. If we run this (remember, this is running batch data parallel, because we just made our data distributed) you can see we get gradients out. If you think about it, the compiler is actually doing something nontrivial here, because to get the gradients it has to automatically insert an all-reduce. But we don't have to worry about that; we just set the sharding, and it's parallelized automatically for us. Here's a simple update step, just simple gradient descent on our parameters. We can see we're doing grad underneath our jit, and we're going to loop, calling this many times. This whole update step is being sent to the compiler, so it can be optimized end to end, and we can see that the loss goes down. That's great. Maybe we can have a bigger step size after all. The loss keeps going down. That's great; that's AI.

I think I might skip this section, but briefly: you can even compose more JAX transformations together, and this really starts to give you a lot of expressive power. In this case we can use vmap, which does efficient vectorized batching of computations, to do things like compute per-example gradients. If you're doing computations for, say, differential privacy, maybe you want per-example gradients. JAX makes that super easy to express, and of course you can compose that with autodiff, with end-to-end just-in-time compilation, and with automatic partitioning.

As one final example, I'll show how we can try different parallelization strategies. In this case we're going to take our mesh of 32 devices and make it two-dimensional, because I want to do some model parallelism, specifically tensor parallelism, and also do some batch data
parallelism at the same time. So we're going to take our 32 devices and think of them as a 16x2 mesh, and then we're going to place our data so that our input data is sharded along its batch dimension and along its feature dimension, and the parameters are also sharded along their leading feature dimension. I'm going to make one change to our example code as well, just as a hint to the compiler about how we want to keep things sharded: I'm going to annotate this intermediate with a sharding constraint to say that I want the results of my matmuls to stay sharded this way, where I'm doing both batch parallelism and tensor parallelism. With that change behind the scenes, let's see what our loss value was before: 415 339 something. And now we're getting this same result out, but instead of just batch data parallelism, we're doing batch and model parallelism. So that's how JAX gives us one unified mechanism for parallelism, to scale out from one chip, to 32 cores in this case, to many thousands. That was just a brief introduction to JAX and why it might be useful to you, and with that I'll end and hand it back to Gadi.

Thank you very much, Matt, this was awesome. You guys, there is a test on the way out, just to make sure you got everything here. I think the main point is that the code is portable: if you are running on Trainium, you use the same code, and if you run on GPU, it will be the same code. Of course, the sharding dimensions differ, because you have eight GPUs in an instance and 16 Trainium chips in a Trainium2 instance, and things of that nature, but largely the code is portable, which is awesome to see. Thank you, Matt. The last thing I want to say is that JAX is going to be very impactful for the customers I mentioned, Anthropic and Apple, and hopefully many other customers will adopt JAX.

I want to thank all of our Trainium2 launch customers and partners. We wouldn't be able to continue to develop Trainium, Neuron, and this amazing technology without our wonderful customers and partners making sure it's easy to use and adopt, and evolving over time to more performance and more scale, so a huge thank you to all of our customers and partners. The last thing I want to highlight: there are a lot of sessions this year around Trainium and Inferentia. It's a bit of a shout-out to the rest of the team that is working really hard at re:Invent. We have more than 30 sessions this week for you, and this is only some of them, so please scan the QR code and you'll get the full list, including hands-on sessions on Trainium and Inferentia. Thank you for coming, I wish you all a great re:Invent, and see you soon. Thank you.
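The notebook from Matt's live demo isn't reproduced in this transcript. As a rough stand-in, here is a minimal JAX sketch of the same sequence of steps he walked through: a toy network in plain `jax.numpy`, `jax.jit`, a 1D batch mesh, `jax.grad` inside a jitted update step, `vmap` for per-example gradients, and a 2D batch-by-model mesh with a sharding constraint on the hidden activations. The layer sizes, tanh activations, squared-error loss, step size of 0.1, and helper names such as `init_params` are hypothetical choices, not the demo's code. It assumes a recent JAX release with the `jax.sharding` module and a device count that divides the batch of 64; a single CPU device works, and on 32 devices the 2D mesh comes out 16x2 as in the demo. The Neuron-specific JAX setup is not shown.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_params(key, sizes):
    """Random (W, b) pairs for a small fully connected network (hypothetical sizes)."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, w_key = jax.random.split(key)
        W = jax.random.normal(w_key, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((W, jnp.zeros(d_out)))
    return params


def predict(params, inputs, constraint=None):
    """tanh hidden layers, linear output; optionally pin the hidden-activation sharding."""
    for W, b in params[:-1]:
        h = inputs @ W + b
        if constraint is not None:
            # Hint to the compiler: keep these matmul results sharded this way.
            h = jax.lax.with_sharding_constraint(h, constraint)
        inputs = jnp.tanh(h)
    W, b = params[-1]
    return inputs @ W + b


def loss(params, inputs, targets, constraint=None):
    """Batch mean of the summed squared error."""
    preds = predict(params, inputs, constraint)
    return jnp.mean(jnp.sum((preds - targets) ** 2, axis=-1))


params = init_params(jax.random.PRNGKey(0), [8, 16, 16, 4])
inputs = jax.random.normal(jax.random.PRNGKey(1), (64, 8))
targets = jax.random.normal(jax.random.PRNGKey(2), (64, 4))

# jax.jit stages the whole loss function out to XLA.
eager_loss = loss(params, inputs, targets)
jit_loss = jax.jit(loss)(params, inputs, targets)

# Batch data parallelism: a 1D mesh over all devices, data split on its leading axis.
devices = np.array(jax.devices())
mesh_1d = Mesh(devices, ("batch",))
batch_sharding = NamedSharding(mesh_1d, P("batch"))
inputs_dp = jax.device_put(inputs, batch_sharding)
targets_dp = jax.device_put(targets, batch_sharding)
dp_loss = jax.jit(loss)(params, inputs_dp, targets_dp)


@jax.jit
def update(params, inputs, targets):
    # grad under jit: with sharded data, XLA inserts whatever all-reduce it needs.
    grads = jax.grad(loss)(params, inputs, targets)
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


trained = params
for _ in range(100):
    trained = update(trained, inputs_dp, targets_dp)
final_loss = jax.jit(loss)(trained, inputs_dp, targets_dp)


# vmap over a single-example loss gives per-example gradients.
def example_loss(params, x, y):
    return loss(params, x[None], y[None])


per_example_grads = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))(
    params, inputs, targets)

# Batch + tensor parallelism: a 2D ("batch", "model") mesh, 16x2 on 32 devices.
n = devices.size
grid = devices.reshape(n // 2, 2) if n % 2 == 0 else devices.reshape(n, 1)
mesh_2d = Mesh(grid, ("batch", "model"))
hidden_sharding = NamedSharding(mesh_2d, P("batch", "model"))
inputs_2d = jax.device_put(inputs, hidden_sharding)  # batch and feature dims sharded
params_2d = [(jax.device_put(W, NamedSharding(mesh_2d, P("model", None))), b)
             for W, b in params]  # each W sharded along its leading feature dim
loss_2d = jax.jit(lambda p, x, y: loss(p, x, y, hidden_sharding))(
    params_2d, inputs_2d, targets)

print(eager_loss, jit_loss, dp_loss, loss_2d, final_loss)
```

The eager, jitted, batch-parallel, and batch-plus-model-parallel losses should agree up to floating-point tolerance, which mirrors the point of the demo: changing the sharding changes how XLA partitions the work, not what is computed.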

2024-12-20 12:29
