Training Models for Machine Learning
As we presented in our previous Authoritative Guide to Data Labeling, machine learning (ML) has revolutionized both state-of-the-art research and the ability of businesses to solve previously challenging or impossible problems in computer vision and natural language processing. Predictive models, trained on vast amounts of data, can now learn and detect patterns reliably, without being explicitly programmed to execute those tasks.
More broadly, ML models can predict numerical outcomes like temperature or a mechanical failure, recognize cars or retail products, plan better ways to grasp objects, and generate useful, coherent, and logical text, all with little or no human involvement. Want to get started training and building models for your business use case? You’ve come to the right place to learn how model training works, and how you too can start building your own ML models!
What Are Machine Learning (ML) Models?
ML models typically take “high-dimensional” sets of data artifacts as inputs and deliver a classification, a prediction, or some other indicator as an output. These inputs can be text prompts, numerical data streams, images or video, audio, or even three-dimensional point cloud data. The computational process of producing the model output is typically called “inference,” a term borrowed from statistics. The model is making a “prediction” based on patterns it learned from historical data.
What distinguishes an ML model from simple heuristics (often conditional statements) or hard-coded feature detectors (yes, face recognition used to depend on detecting a specific configuration of circles and lines!) is a series of “weights,” typically floating-point numbers, grouped in “layers” and linked by functions. The system is trained iteratively, adjusting the weights over time to reduce error (a metric typically referred to as “loss” in the ML world). In nearly all ML models, there are too many weights to adjust manually or selectively; they must be “trained” iteratively and automatically to produce a useful and capable model. Ideally, the resulting model has “learned” from the training examples and can generalize to new examples it hasn’t seen before in the real world.
Because these weights are trained iteratively, the ML engineer designing the system can, in most cases, only speculate about the contribution of each individual weight to the final model. Instead, she tweaks and tunes the dataset, model architecture, and hyperparameters. In a way, the ML engineer “steers the ship” rather than micromanaging the finest details of the model. Training proceeds over many “epochs,” each one a full pass through the training data, interleaved with rounds of evaluation, and the goal is to drive the model’s error, or loss, as low as possible. Typically, when a model “converges,” its loss settles into a minimum (usually a local minimum or plateau, not a guaranteed global minimum) and stabilizes. At this point, the model is deemed “as good as it’s going to get,” in the sense that further training is unlikely to yield meaningful performance improvements.
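To make “weights,” “loss,” and “epochs” concrete, here is a minimal sketch in plain Python (no ML library assumed): gradient descent fitting a single weight `w` and bias `b` to toy data generated from y = 3x + 1. The function name, learning rate, and epoch count are illustrative choices, not values from any particular framework. Because the toy dataset is tiny, each epoch here is a single weight update over all five examples.

```python
def train_linear(xs, ys, learning_rate=0.05, epochs=500):
    """Fit y = w * x + b by gradient descent on mean squared error (the "loss")."""
    w, b = 0.0, 0.0
    n = len(xs)
    loss_history = []
    for _ in range(epochs):
        # Forward pass: predictions and errors for every training example.
        preds = [w * x + b for x in xs]
        errors = [p - y for p, y in zip(preds, ys)]
        loss_history.append(sum(e * e for e in errors) / n)
        # Gradients of the mean squared error with respect to w and b.
        grad_w = sum(2 * e * x for e, x in zip(errors, xs)) / n
        grad_b = sum(2 * e for e in errors) / n
        # Nudge each weight a small step against its gradient.
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b, loss_history


# Toy data drawn from y = 3x + 1 (hypothetical, noise-free).
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [3 * x + 1 for x in xs]
w, b, loss_history = train_linear(xs, ys)
print(f"w={w:.3f}, b={b:.3f}, final loss={loss_history[-1]:.2e}")
```

Real models have millions or billions of weights and use mini-batches, but the loop has the same shape: predict, measure loss, compute gradients, update.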
Sometimes it’s possible to detect that a model’s performance metrics have stabilized and apply a technique known as “early stopping”: there’s no point in spending additional time and compute on training that doesn’t meaningfully improve the model. At this stage, you can evaluate your model to see whether it’s ready for production. Real-world user testing often helps determine whether you’re “ready” to launch the product that encapsulates your model, or whether you need to keep tweaking, adding data, and re-training. In most applications, changes in the real-world data a model sees will eventually cause failures or “drift,” requiring continued maintenance and improvement of your model.
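Early stopping is commonly implemented with a “patience” counter: stop once the validation loss has failed to improve for a set number of consecutive epochs. The sketch below is a framework-free illustration; the function name, the `patience` and `min_delta` parameters, and the loss values are all hypothetical (frameworks such as Keras ship their own `EarlyStopping` callback).

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Training stops once the loss has failed to improve by more than min_delta
    for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


# Hypothetical validation losses: improvement stalls, then overfitting creeps in.
val_losses = [0.90, 0.60, 0.45, 0.40, 0.41, 0.42, 0.43, 0.44]
stop_epoch, best_epoch = early_stopping(val_losses)
print(stop_epoch, best_epoch)  # 6 3
```

In practice you would also keep a copy of the weights from `best_epoch` and restore them after stopping.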
Divvying up your data
To train a model that can properly “generalize” to data it has never seen before, it’s helpful to train it on 50-90% of the available data, hold out 5-20% as a “validation” set for tuning hyperparameters, and save another 5-20% as a “test” set to measure final model performance. It’s important not to “taint” or “contaminate” the training set with data the model will later be tested on: if identical assets appear in both the training and test sets, the model can “memorize” those examples, overfitting to them and compromising its ability to generalize, which is an important attribute of nearly every successful ML model. Some researchers refer to the test set as the “hold-out” or “held-out” set. You can think of it as the model’s “final exam,” which the model shouldn’t have seen verbatim before exam day, even if it has seen similar examples in the metaphorical problem sets (whose answers it checked against the answer key) during training.
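A minimal sketch of an 80/10/10 split (within the ranges above), using only the standard library. The function name, seed, and asset IDs are hypothetical. It drops exact duplicates before shuffling so the same asset cannot land in both training and test; near-duplicates (resized or re-encoded copies, for example) would need a separate check.

```python
import random


def train_val_test_split(items, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle and split hashable items into train/validation/test lists.

    Exact duplicates are dropped first so the same asset cannot land in both
    the training set and a held-out set.
    """
    unique_items = list(dict.fromkeys(items))  # de-duplicate, preserving order
    random.Random(seed).shuffle(unique_items)  # seeded so the split is reproducible
    n = len(unique_items)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = unique_items[:n_test]
    val = unique_items[n_test:n_test + n_val]
    train = unique_items[n_test + n_val:]
    return train, val, test


# Hypothetical dataset of 100 asset IDs, with one accidental duplicate.
asset_ids = [f"img_{i:03d}" for i in range(100)] + ["img_007"]
train, val, test = train_val_test_split(asset_ids)
print(len(train), len(val), len(test))  # 80 10 10
```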
What types of data can ML models be trained on?
If you’re simply interested in computer vision, or some of the more sophisticated and recent data types that ML models can tackle, skip ahead to the Computer Vision section. That said, working with tabular data is helpful to understand how we arrived at deep learning and convolutional neural networks for more complex data types. Let’s begin with this simpler data type.
Tabular Data
Tabular data consists of rows and columns. Each row is a granular entry, such as a timestamp, a person, or a transaction, and the columns hold attributes of that entry, often of different data types. Collectively, these columns can serve as “features” that help the model reliably predict an outcome. As a data scientist, you can also multiply, subtract, or otherwise combine columns and train the model on those engineered features. For tabular data, there is a wide variety of models you can apply to predict a label, a score, or any other (often synthesized) metric from the inputs. It’s often helpful to eliminate columns that are “collinear” (highly correlated with one another), although some models are designed to deprioritize columns that are effectively redundant for predicting an outcome.
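One simple way to spot collinear columns is to compute the pairwise Pearson correlation and flag pairs above a threshold. The sketch below assumes numeric, non-constant columns stored as lists in a dict; the column names, values, and 0.95 threshold are hypothetical.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient between two equal-length numeric columns."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def collinear_pairs(columns, threshold=0.95):
    """List column pairs whose absolute correlation meets the threshold."""
    names = list(columns)
    flagged = []
    for i, left in enumerate(names):
        for right in names[i + 1:]:
            r = pearson(columns[left], columns[right])
            if abs(r) >= threshold:
                flagged.append((left, right, round(r, 3)))
    return flagged


# Hypothetical table: price_eur is just price_usd converted, so it is redundant.
price_usd = [10.0, 12.0, 15.0, 20.0, 22.0]
table = {
    "price_usd": price_usd,
    "price_eur": [round(0.92 * p, 2) for p in price_usd],
    "units_sold": [50, 48, 30, 35, 20],
}
print(collinear_pairs(table))  # only the price_usd/price_eur pair is flagged
```

Correlation only catches pairwise linear redundancy; a column that is a combination of several others needs a more thorough check.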
Tabular data continues the paradigm of separating training and test data, so that the model doesn’t “memorize” the training data and overfit, regurgitating examples it has seen but fumbling its response to ones it hasn’t. Tables also make it easy to rotate which rows you use for each purpose (best practice is to shuffle the rows first, or randomize the split), so that the training and held-out portions can be swapped across your dataset. This is known as k-fold cross-validation: the table is split into k roughly equal “folds,” and the model is trained k times, each time holding out a different fold for evaluation and training on the remaining k-1.
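Here is a minimal, standard-library sketch of generating k-fold splits over row indices. The function name and seed are hypothetical; scikit-learn’s `KFold` class provides the same idea with more options.

```python
import random


def k_fold_splits(n_rows, k=5, seed=0):
    """Yield (train_indices, held_out_indices) for each of k folds."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # shuffle rows before folding
    # Spread any remainder so fold sizes differ by at most one row.
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(indices[start:start + size])
        start += size
    for i, held_out in enumerate(folds):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, held_out


for fold_number, (train_idx, held_idx) in enumerate(k_fold_splits(10, k=3)):
    print(fold_number, len(train_idx), len(held_idx))
```

Every row is held out exactly once, so averaging the k evaluation scores gives a less noisy estimate than a single split.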
A final point about tabular data, which we’ll revisit in the computer vision section, is that data often has to be scaled to a usable range. This might mean mapping values that range from 1 to 1,000,000 onto floating-point numbers between 0 and 1.0, or between -1.0 and 1.0. ML engineers often need to experiment with different types of scaling (a logarithmic transform is useful for some datasets) for the model to reach its most robust state and generate the most accurate predictions possible.
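The sketch below contrasts plain min-max scaling with a log transform followed by min-max scaling, using the 1 to 1,000,000 range mentioned above. Function names are hypothetical, and the log variant assumes strictly positive values.

```python
import math


def min_max_scale(values, low=0.0, high=1.0):
    """Linearly map values onto [low, high]."""
    v_min, v_max = min(values), max(values)
    span = v_max - v_min
    if span == 0:
        return [low for _ in values]  # constant column: nothing to spread out
    return [low + (v - v_min) * (high - low) / span for v in values]


def log_min_max_scale(values, low=0.0, high=1.0):
    """Take log10 first (values must be positive), then min-max scale."""
    return min_max_scale([math.log10(v) for v in values], low, high)


raw = [1, 10, 1_000, 1_000_000]
print(min_max_scale(raw))                 # all but the largest value crowd near 0
print(log_min_max_scale(raw, -1.0, 1.0))  # spread out by order of magnitude
```

One practical caution: compute the minimum and maximum on the training set only, then reuse them for validation and test data, so no information leaks from the held-out sets.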
Text
No discussion of machine learning would be complete without text. Large language models have stolen the show in recent years, and they generally serve two roles:
- Translate text from one language to another (even a language with only minimal examples on the internet)
- Predict the next section of text—this might be a synthesized verse of Rumi or Shakespeare, a typical response to a common IT support request, or even an impressively cogent response to a specific question
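Large language models are vastly more sophisticated, but the core task of next-word prediction can be illustrated with a toy bigram model that simply counts which word follows which. This is a plain-Python sketch with a hypothetical corpus and function names; real LLMs use neural networks over sub-word tokens and long contexts, not raw counts.

```python
from collections import Counter, defaultdict


def train_bigrams(text):
    """Count, for each word, which words follow it in the text."""
    words = text.lower().split()
    followers = defaultdict(Counter)
    for current, nxt in zip(words, words[1:]):
        followers[current][nxt] += 1
    return followers


def predict_next(followers, word):
    """Return the most frequent follower of `word`, or None if it was never seen."""
    counts = followers.get(word.lower())
    if not counts:
        return None
    return counts.most_common(1)[0][0]


model = train_bigrams("to be or not to be that is the question to be")
print(predict_next(model, "to"))  # be
```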
In addition to deep and large models, a number of other techniques can be applied to text, often in conjunction with large language models, including unsupervised techniques like clustering, principal components analysis (PCA), and latent Dirichlet allocation (LDA). Since these aren’t technically “deep learning” or “supervised learning” approaches, feel free to explore them on your own! They may prove useful in conjunction with a model trained on labeled text data.
For any text modeling approach, it’s also important to consider “tokenization”: splitting the training text into useful chunks, or “tokens,” which might be individual words, full sentences, word stems, or even syllables or other sub-word pieces. Although it’s now fairly old, the Python-based Natural Language Toolkit (NLTK) includes the widely used Treebank tokenizer, as well as the Snowball stemmer for reducing words to their stems. spaCy and Gensim also include more modern tokenizers, and the PyTorch ecosystem (for example, the torchtext library) provides tokenizers as well.
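To show what a word-level tokenizer does, here is a deliberately simple regular-expression sketch using only Python’s `re` module. The pattern and function name are hypothetical; it handles ASCII letters only, whereas the NLTK and spaCy tokenizers mentioned above handle Unicode, abbreviations, and many more edge cases.

```python
import re

# Words (keeping contractions like "don't" whole), numbers (including decimals),
# and individual punctuation marks.
TOKEN_PATTERN = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?|\d+(?:\.\d+)?|[^\w\s]")


def tokenize(text):
    """Split text into word, number, and punctuation tokens."""
    return TOKEN_PATTERN.findall(text)


print(tokenize("Don't panic: the GPU costs $3.50!"))
```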
But back to large language models: it’s typically helpful to train them on very large text datasets, generally known as “corpuses” (or, if you’re a morbid scholar of Latin, “corpora”). Because text requires much less storage than high-resolution imagery, a corpus can be massive: the entire Shakespeare canon, every Wikipedia article ever written, every line of public code on GitHub, every public-domain book available on Project Gutenberg, or, if you’d rather write a scraping tool, as much of the textual internet as you can save on a storage device quickly accessible by the system on which you plan to train your model.
Large language models (LLMs) can be “re-trained” or “fine-tuned” on text data specific to your use case. These might be common questions paired with high-quality responses, or simply a large body of text from the same company or author, from which the model learns to predict the next n words in a document. (In natural language processing, every body of text is considered a “document”!) Similarly, translation models can start from a previously trained, generic LLM and be fine-tuned to support the input-output mapping that translating from one language to another requires.
Images
Images are one of the earliest data types on which researchers trained truly “deep” neural networks, both in the past decade and in the 1990s. Compared to tabular data, and in some cases even audio, uncompressed images take up a lot of storage. Image size scales not only with width and height (in pixels) but also with color depth. (For example, do you want to store color information or only brightness values? In how many bits? And how many channels?) Handwritten digit recognition is one of the simpler image tasks, since it requires only low-resolution grayscale pixel values. In the 1990s, neural networks were laboriously trained on large sequential computers, adjusting the model weights until handwriting recognition became practical, in the form of Pitney Bowes’ postal code detector in conjunction with Yann LeCun’s LeNet network. The MNIST handwritten digits dataset is still considered a useful baseline for proving the minimal viability of a new model architecture.
More broadly speaking, images include digital photographs as most folks know and love them, captured on digital cameras or smartphones. Color images typically have 3 channels, one each for red, green, and blue, and each channel’s brightness is usually encoded as an 8-bit value (10 bits or more for high-dynamic-range formats). Some ML model layers will simply look at brightness (in the form of a grayscale image), while others may learn their “features” from specific color channels while ignoring patterns in other channels.
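The storage arithmetic above, plus a grayscale conversion, fits in a few lines. This is a sketch with hypothetical function names; the grayscale weights are the ITU-R BT.601 luma coefficients, one common convention among several.

```python
def uncompressed_bytes(width, height, channels=3, bits_per_channel=8):
    """Storage for one raw image: pixels x channels x bits, converted to bytes."""
    return width * height * channels * bits_per_channel // 8


def to_grayscale(rgb_pixels):
    """Collapse (R, G, B) pixels to single brightness values (BT.601 luma weights)."""
    return [round(0.299 * r + 0.587 * g + 0.114 * b) for r, g, b in rgb_pixels]


print(uncompressed_bytes(28, 28, channels=1))               # 784: an MNIST-sized digit
print(uncompressed_bytes(1920, 1080))                       # 6220800: about 6.2 MB
print(uncompressed_bytes(1920, 1080, bits_per_channel=10))  # 7776000
print(to_grayscale([(255, 255, 255), (255, 0, 0), (0, 0, 0)]))
```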
Many models are trained on CIFAR (smaller; CIFAR-10 and CIFAR-100 have 10 and 100 label classes, respectively), ImageNet (larger; its widely used classification benchmark has 1000 label classes), or Microsoft’s COCO dataset (which includes per-pixel labels). Once these models reach a reasonable level of accuracy, they can be re-trained or fine-tuned on data specific to a use case: for example, breeds of dogs and cats, or, more practically, vehicle types.
Video
Video is typically a sequence of images combined with an audio track. Individual frames can be used as training data, as can the “delta,” or difference, between one frame and the next. (Sometimes this is expressed as a series of motion vectors forming a lower-resolution overlay on top of the image.) Generally speaking, individual video frames can be processed by a model just like standalone images; the difference is that the model can exploit the overlap between adjacent frames, since an object detected in one frame is often at a nearby location in the next. Contextual clues can also assist per-frame computer vision, including speech recognition or sound detection from the paired audio track.
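A frame “delta” is just a per-pixel subtraction. The sketch below assumes two equally sized grayscale frames stored as nested lists; the frame contents, function names, and change threshold are hypothetical.

```python
def frame_delta(prev_frame, curr_frame):
    """Per-pixel brightness change between two equally sized grayscale frames."""
    return [
        [curr - prev for prev, curr in zip(prev_row, curr_row)]
        for prev_row, curr_row in zip(prev_frame, curr_frame)
    ]


def changed_fraction(delta, threshold=10):
    """Fraction of pixels whose brightness changed by more than `threshold`."""
    magnitudes = [abs(v) for row in delta for v in row]
    return sum(m > threshold for m in magnitudes) / len(magnitudes)


# Hypothetical 3x4 frames: a bright pixel moves one step to the right.
frame_1 = [[0, 0, 0, 0], [0, 200, 0, 0], [0, 0, 0, 0]]
frame_2 = [[0, 0, 0, 0], [0, 0, 200, 0], [0, 0, 0, 0]]
delta = frame_delta(frame_1, frame_2)
print(delta[1])                 # [0, -200, 200, 0]
print(changed_fraction(delta))  # 2 of 12 pixels changed
```

Because most pixels in consecutive frames barely change, deltas like this are mostly zeros, which is part of why video codecs and some video models work with differences rather than full frames.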
Audio
Sound waves, digitized in binary representation, are useful not just for playback and mixing, but also for speech recognition and sentiment analysis. Audio files are often compressed with codecs such as Vorbis (usually in an Ogg container), AAC, or MP3, but they typically decompress to 8-, 16-, or 24-bit amplitude values, with sample rates anywhere between 8 kHz and 192 kHz (common rates include 16, 44.1, and 48 kHz). Voice recordings generally need lower sample rates and bit depths, even for accurate recognition and (relatively) convincing synthesis. While traditional speech-to-text services used Hidden Markov Models (HMMs), long short-term memory networks (LSTMs) have since stolen the spotlight for speech recognition, and they have powered voice-based assistants you might be familiar with, such as Alexa, Google Assistant, Siri, and Cortana. Training text-to-speech and speech-to-text models can require substantial compute, though work is ongoing to reduce the barriers to entry for these applications. As with many other use cases, transformers have proven valuable in increasing the accuracy and noise-robustness of speech-to-text applications. Beyond the digital assistants above, OpenAI’s Whisper demonstrates how robust these algorithms have become to mumbling, background interference, and other obstacles. Unlike those assistants, Whisper isn’t tied to a consumer-oriented end product: it can be called via an API, and its model weights have been released as open source.
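The sample rate and bit depth figures above translate directly into data rates, and “bit depth” simply means the integer range each sample is quantized into. A standard-library sketch with hypothetical function names; 16 kHz mono is shown because it is a common input format for speech models (Whisper, for example, resamples its input to 16 kHz).

```python
import math


def pcm_bytes_per_second(sample_rate_hz, bit_depth, channels):
    """Data rate of uncompressed PCM audio."""
    return sample_rate_hz * bit_depth * channels // 8


def quantized_sine(freq_hz, sample_rate_hz, n_samples, bit_depth=16):
    """Sample a sine wave and quantize it to signed integers of the given bit depth."""
    max_amplitude = 2 ** (bit_depth - 1) - 1  # 32767 for 16-bit audio
    return [
        round(max_amplitude * math.sin(2 * math.pi * freq_hz * i / sample_rate_hz))
        for i in range(n_samples)
    ]


print(pcm_bytes_per_second(44_100, 16, 2))  # 176400: CD-quality stereo
print(pcm_bytes_per_second(16_000, 16, 1))  # 32000: 16 kHz mono speech
samples = quantized_sine(1_000, 8_000, 8)   # one cycle of a 1 kHz tone sampled at 8 kHz
print(samples)
```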
3D Point Clouds
Point clouds encode the positions of any number of points in three-dimensional space, usually in Cartesian coordinates (at the sensor output, often in spherical coordinates: a range plus two angles). What might at first look like a jumble of dots on screen can typically be “spun” around an arbitrary axis with the mouse, revealing that the arrangement of points is actually three-dimensional. Often, this data is captured by lidar: laser range-finders that spin radially, mounted on the roof of a moving vehicle. For indoor scenes, a “structured light” infrared camera or a “time-of-flight” camera can capture similar data (usually at higher point density). You may be familiar with such cameras thanks to the Xbox Kinect: the original Kinect used structured light, and its successor used time-of-flight. In either case, sufficiently detailed point clouds can be captured to perform object recognition on whatever is in the sensor’s field of view.
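Converting raw lidar returns from spherical to Cartesian coordinates is a typical first step before training on point clouds. This sketch assumes a specific angle convention (stated in the docstring); real sensors document their own conventions and calibration offsets, and the scan values here are hypothetical.

```python
import math


def spherical_to_cartesian(range_m, azimuth_deg, elevation_deg):
    """Convert one lidar return to (x, y, z) in meters.

    Assumed convention: azimuth is measured in the horizontal plane from the
    +x axis toward +y, and elevation is measured up from the horizontal plane.
    """
    azimuth = math.radians(azimuth_deg)
    elevation = math.radians(elevation_deg)
    horizontal = range_m * math.cos(elevation)  # projection onto the ground plane
    return (
        horizontal * math.cos(azimuth),
        horizontal * math.sin(azimuth),
        range_m * math.sin(elevation),
    )


# Hypothetical returns: (range in meters, azimuth in degrees, elevation in degrees).
scan = [(10.0, 0.0, 0.0), (10.0, 90.0, 0.0), (5.0, 45.0, 30.0)]
point_cloud = [spherical_to_cartesian(*ret) for ret in scan]
for x, y, z in point_cloud:
    print(f"{x:7.3f} {y:7.3f} {z:7.3f}")
```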
Multimodal Data
It turns out the unreasonable effectiveness of deep learning doesn’t require that you train on only one data type at a time. In fact, you can train a model on images and audio to produce text, or train on multi-camera input in which the frames from all cameras capturing an object at a single point in time are combined into a single input. While such combined data might be confusing to a human, certain deep neural networks perform well on mixed input types, often treated as one large (or “long”) input vector. Similarly, a model can be trained on pairs of different data types, such as text prompts and image outputs.
What are some common classes or types of ML models?
Support Vector Machines (SVMs)
SVMs are one of the more classic forms of machine learning. They are typically used as classifiers (yes, it’s OK to think of “hot dog” versus “not hot dog,” or, preferably, cat versus dog as labels here), and while they are no longer the state of the art for computer vision, they remain useful for “non-linear” classification problems that a simple linear or logistic regression can’t handle. They do this through the “kernel trick,” which implicitly maps the data into a higher-dimensional space where a linear boundary (a maximum-margin hyperplane) can separate the classes. You can learn more about SVMs on their documentation page at the scikit-learn website. We’ll continue to use scikit-learn as a reference for non-state-of-the-art models, because its documentation and examples are arguably the most thorough available. And scikit-learn (discussed in greater depth below) is a great tool for managing your dataset in Python and for checking whether simpler, cheaper-to-train, or computationally less complex models are sufficient for your use case. (You can also use these simpler models as baselines against which to compare your deep neural nets trained on Scale infrastructure!)
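To show the idea behind an SVM’s margin, here is a from-scratch sketch of a linear SVM trained by stochastic subgradient descent on the regularized hinge loss. It is an illustration, not scikit-learn’s solver (in practice you would use `sklearn.svm.SVC` or `LinearSVC`), and it omits the kernel trick, so it only learns straight-line boundaries. Function names, hyperparameters, and data are hypothetical.

```python
def train_linear_svm(points, labels, reg=0.01, learning_rate=0.1, epochs=200):
    """Fit a linear SVM (w, b) by stochastic subgradient descent on the hinge loss.

    labels must be +1 or -1. No kernel: only hyperplane boundaries are learned.
    """
    w = [0.0] * len(points[0])
    b = 0.0
    for _ in range(epochs):
        for x, y in zip(points, labels):
            margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
            # L2 regularization shrinks the weights a little on every step.
            w = [wi - learning_rate * reg * wi for wi in w]
            if margin < 1:
                # Inside the margin or misclassified: move the boundary away from x.
                w = [wi + learning_rate * y * xi for wi, xi in zip(w, x)]
                b += learning_rate * y
    return w, b


def svm_predict(w, b, x):
    return 1 if sum(wi * xi for wi, xi in zip(w, x)) + b >= 0 else -1


# Hypothetical 2-D data: class +1 up and to the right, class -1 down and to the left.
points = [(2, 2), (3, 3), (3, 1), (-2, -2), (-3, -1), (-1, -3)]
labels = [1, 1, 1, -1, -1, -1]
w, b = train_linear_svm(points, labels)
print(w, b, [svm_predict(w, b, p) for p in points])
```

The “margin < 1” check is what makes this an SVM rather than a plain perceptron: points that are correctly classified but too close to the boundary still push it away.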
Random Forest Classifiers
Random Forest Classifiers are a robust, general-purpose choice, whether you’re modeling tabular data with lots of collinearities or you need a solution that’s not computationally complex. The “forest” is an ensemble of decision trees: each tree is trained on a random bootstrap sample of the rows, considering a random subset of the columns at each split, and the forest predicts by majority vote across its trees. Random Forest Classifiers can find non-linear boundaries between adjacent classes, but because each split looks at one feature at a time, the boundary is a step-like approximation rather than a perfect trace. You can learn more about Random Forest Classifiers, again, in scikit-learn’s documentation.
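The three ingredients described above (bootstrap samples, random feature choice, majority vote) can be shown with a toy forest of depth-1 trees, or “stumps.” This is a simplified, hypothetical sketch, not scikit-learn’s `RandomForestClassifier`, whose trees are much deeper and which exposes parameters such as `n_estimators` and `max_features`.

```python
import random
from collections import Counter


def fit_stump(rows, labels, feature):
    """Fit a depth-1 decision tree ("stump") that splits on a single feature."""
    best = None
    for threshold in sorted({row[feature] for row in rows}):
        left = [y for row, y in zip(rows, labels) if row[feature] <= threshold]
        right = [y for row, y in zip(rows, labels) if row[feature] > threshold]
        left_label = Counter(left).most_common(1)[0][0]
        right_label = Counter(right).most_common(1)[0][0] if right else left_label
        errors = sum(y != left_label for y in left) + sum(y != right_label for y in right)
        if best is None or errors < best[0]:
            best = (errors, threshold, left_label, right_label)
    _, threshold, left_label, right_label = best
    return feature, threshold, left_label, right_label


def fit_forest(rows, labels, n_trees=25, seed=0):
    rng = random.Random(seed)
    forest = []
    for _ in range(n_trees):
        # Bootstrap sample: draw len(rows) rows with replacement.
        sample = [rng.randrange(len(rows)) for _ in rows]
        # Random feature choice (a "subset" of size one, to keep the sketch small).
        feature = rng.randrange(len(rows[0]))
        forest.append(fit_stump([rows[i] for i in sample], [labels[i] for i in sample], feature))
    return forest


def predict_forest(forest, row):
    """Majority vote across all trees."""
    votes = [left if row[f] <= t else right for f, t, left, right in forest]
    return Counter(votes).most_common(1)[0][0]


# Hypothetical two-feature data set with two well-separated classes.
rows = [(1, 1), (2, 1), (1, 2), (2, 2), (8, 8), (9, 8), (8, 9), (9, 9)]
labels = ["cat"] * 4 + ["dog"] * 4
forest = fit_forest(rows, labels)
print(predict_forest(forest, (0, 0)), predict_forest(forest, (10, 10)))
```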
Gradient-Boosted Trees
Gradient boosting is another technique applied to decision trees: trees are added one at a time, each fitted to correct the remaining errors (more precisely, the gradient of the loss) of the ensemble built so far. These models also handle collinearities, and they expose a number of hyperparameters that can limit overfitting (memorizing the training set so that the model scores well on training data, but that accuracy does not extend to the held-out test set). XGBoost is the framework that took Kaggle by storm, and LightGBM and CatBoost also have strong followings for models of this class. At this level of complexity, the hyperparameter count grows well beyond that of the simpler regression-derived models in scikit-learn. (Basically, there are now more ways in which you can tune your model, increasing complexity.) You can read all about how XGBoost works in its documentation. While there are techniques to attribute model outputs to specific input columns or “features,” such as Shapley values, XGBoost models certainly demonstrate that not every ML model is truly “explainable.” As you might guess, explainability is challenged by model complexity, so as we dive deeper into complex neural networks, you can begin to think of them more as “black boxes.” Not every step of your model will necessarily be explainable, nor will it necessarily be useful to hypothesize why the values of each layer in your model ended up the way they did after training has completed.
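A minimal sketch of gradient boosting for regression with squared-error loss, where the negative gradient is simply the residual, so each new stump is fitted to what the ensemble still gets wrong. Function names, data, and hyperparameters are hypothetical; XGBoost and its peers add second-order gradient information, regularization, deeper trees, and much more.

```python
def fit_regression_stump(xs, targets):
    """Best single-threshold split of 1-D inputs, predicting the mean target on each side."""
    best = None
    for threshold in sorted(set(xs))[:-1]:  # skip the max so both sides are non-empty
        left = [t for x, t in zip(xs, targets) if x <= threshold]
        right = [t for x, t in zip(xs, targets) if x > threshold]
        left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)
        sse = sum((t - left_mean) ** 2 for t in left) + sum((t - right_mean) ** 2 for t in right)
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    _, threshold, left_mean, right_mean = best
    return lambda x: left_mean if x <= threshold else right_mean


def gradient_boost(xs, ys, n_rounds=50, learning_rate=0.3):
    """Gradient boosting with squared-error loss, using stumps as weak learners."""
    base = sum(ys) / len(ys)  # start from the mean prediction
    stumps = []

    def predict(x):
        return base + learning_rate * sum(stump(x) for stump in stumps)

    for _ in range(n_rounds):
        # For squared error, the negative gradient is simply the residual.
        residuals = [y - predict(x) for x, y in zip(xs, ys)]
        stumps.append(fit_regression_stump(xs, residuals))
    return predict


def mse(predict, xs, ys):
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.0, 2.0, 2.0, 5.0, 5.0, 9.0, 9.0]  # hypothetical step-shaped target
model = gradient_boost(xs, ys)
baseline_mse = sum((y - sum(ys) / len(ys)) ** 2 for y in ys) / len(ys)
print(f"baseline MSE {baseline_mse:.3f} -> boosted MSE {mse(model, xs, ys):.3f}")
```

The `learning_rate` (often called shrinkage) and `n_rounds` are exactly the kind of overfitting-control hyperparameters mentioned above: smaller steps and fewer rounds make the ensemble less eager to memorize the training set.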
Feedforward Neural Networks
These are the simplest form of neural network: they accept a fixed-length input and produce an output, such as a classification. Besides the input and output layers, the network may have intermediate layers, termed “hidden” layers, each with its own set of weights. “Feedforward” means that there are no backward links to earlier layers: every connection in the graph points forward, from input toward output.
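Here is a forward pass through a tiny feedforward network, written in plain Python so every multiply-and-add is visible. The weights are hand-picked and hypothetical (a trained network would learn them), and ReLU is used as the hidden-layer activation, one common choice.

```python
def dense(inputs, weights, biases):
    """One fully connected layer: weights[j][i] links input i to output unit j."""
    return [sum(w * x for w, x in zip(row, inputs)) + b for row, b in zip(weights, biases)]


def relu(values):
    return [max(0.0, v) for v in values]


def feedforward(x, layers):
    """Pass x forward through each (weights, biases) layer; ReLU on hidden layers only."""
    h = x
    for i, (weights, biases) in enumerate(layers):
        h = dense(h, weights, biases)
        if i < len(layers) - 1:
            h = relu(h)
    return h


# Hypothetical hand-picked weights: 2 inputs -> 2 hidden units -> 1 output score.
layers = [
    ([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),
    ([[1.0, 1.0]], [0.0]),
]
print(feedforward([2.0, 1.0], layers))  # [2.5]
```

Data only ever flows from one layer to the next, which is what makes the network “feedforward.”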
Recurrent Neural Networks
The next evolution in neural networks was to make them “recurrent.” In a recurrent neural network (RNN), a layer’s output at one step of a sequence is fed back as an additional input at the next step, so the network carries a “hidden state” from one element of the sequence to the next. For training, this loop can be “unrolled” into a feedforward network with one copy of the layer per time step. Because the same weights are applied at every step, RNNs can perform inference on inputs of varying lengths without forcing every input into the fixed-length shape a feedforward network requires; the trade-off is that inference time grows with the length of the input.
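A single-unit recurrent layer makes the “same weights at every step” idea concrete. The weight values are hypothetical, and `tanh` is used as the activation, a common choice for simple RNNs.

```python
import math


def rnn_forward(sequence, w_input=0.5, w_hidden=0.8, bias=0.0):
    """Run a one-unit recurrent layer over a sequence of any length.

    The same three weights are reused at every step; the hidden state h is
    what carries information forward from earlier elements.
    """
    h = 0.0
    states = []
    for x in sequence:  # each iteration is one "unrolled" copy of the layer
        h = math.tanh(w_input * x + w_hidden * h + bias)
        states.append(h)
    return states


print(rnn_forward([1.0, 0.0, 0.0]))  # the first input's influence lingers, then fades
print(len(rnn_forward([0.2] * 7)), len(rnn_forward([0.2] * 2)))  # 7 2
```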
Convolutional Neural Networks
Convolutional Neural Nets have now been practically useful for roughly 10 years, including on higher-resolution color imagery, in large part thanks to graphics processors (GPUs). There’s much to discuss with CNNs, so let’s start with the name. “Neural net” implies a graph architecture that loosely mimics the interconnected neurons of the brain. While organic neurons are very much analog devices, they do transmit signals electrically (and chemically), somewhat like silicon-based computers. The human vision system is massively parallel, and so is convolution: the same small computation runs independently at every position in the image, which is why a parallel computing architecture like a GPU is well suited to compute the training updates for a vision model, and even to perform inference (detection or classification, for example) with the trained model. Finally, the word “convolution” refers to a mathematical operation that combines multiplication and addition: a small grid of weights, called a kernel, slides across the image, and at each position the overlapping values are multiplied pairwise and summed. We’ll describe it in greater detail in a later section on model layers.
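The sliding multiply-and-sum can be written directly in plain Python. Like most deep learning libraries, this sketch does not flip the kernel (strictly speaking, that makes it cross-correlation), and it uses “valid” padding, so the output is smaller than the input. The image and the edge-detecting kernel are hypothetical.

```python
def conv2d(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1), multiplying and summing."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Every output position is independent: this is the part GPUs parallelize.
            total = 0.0
            for di in range(kh):
                for dj in range(kw):
                    total += image[i + di][j + dj] * kernel[di][dj]
            row.append(total)
        output.append(row)
    return output


# Hypothetical 4x4 image: dark left half, bright right half.
image = [[0, 0, 1, 1]] * 4
vertical_edge = [[-1, 1], [-1, 1]]  # responds where brightness rises left-to-right
print(conv2d(image, vertical_edge))  # strongest response in the middle column
```

In a CNN, the kernel values are not hand-picked like this; they are weights learned during training.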
Long Short-Term Memory Networks (LSTMs)
If you’ve been following along thus far, you might notice that most models you’ve encountered up to this point, apart from RNNs, have no notion of “memory” from one input to the next: every inference performed at runtime depends only on the model and the current input, with no influence from any inference that came before it. Even basic RNNs struggle to retain information across long sequences. That’s where “Long Short-Term Memory” networks, or LSTMs, come in. Each LSTM unit maintains a “cell state” guarded by learned “gates”: a forget gate decides how much of the existing state to keep, an input gate decides how much new information to write, and an output gate decides how much of the state to expose. When the forget gate stays open, information (and its gradient) can pass through many time steps nearly unchanged. This mitigates the “vanishing gradient” problem in RNNs, in which the gradients used to update the weights shrink toward zero as they are propagated back through many time steps, so the network struggles to learn long-range dependencies. Because the cell state can carry information across many steps, or be selectively weakened, strengthened, or left alone, it is loosely analogous to physiological “short-term memory” in the human brain. Turning back from theory to applications: LSTMs became the workhorse for sequence tasks such as speech recognition and machine translation in the mid-2010s. (The widely publicized 2012 Google experiment in which a network trained on unlabeled YouTube video frames developed a unit that responded to cat faces, and could be probed to reveal the rough outline of a cat face, used a large unsupervised autoencoder-style network rather than an LSTM.)
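A single-unit LSTM step, in plain Python, shows how the gates control the cell state. All weights here are hypothetical scalars chosen to demonstrate the two extremes: with the forget gate open and the input gate closed, the cell state survives many steps almost unchanged; with the forget gate closed, it is wiped almost immediately.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, p):
    """One step of a single-unit LSTM cell; p holds (hypothetical) scalar weights."""
    f = sigmoid(p["w_f"] * x + p["u_f"] * h_prev + p["b_f"])    # forget gate: keep old state?
    i = sigmoid(p["w_i"] * x + p["u_i"] * h_prev + p["b_i"])    # input gate: write new info?
    o = sigmoid(p["w_o"] * x + p["u_o"] * h_prev + p["b_o"])    # output gate: expose state?
    g = math.tanh(p["w_g"] * x + p["u_g"] * h_prev + p["b_g"])  # candidate new content
    c = f * c_prev + i * g  # cell state: gated, mostly additive update
    h = o * math.tanh(c)    # hidden state passed to the next step
    return h, c


def carry_cell_state(params, steps=20, c0=1.0):
    """Feed zeros for `steps` steps and report what is left of the initial cell state."""
    h, c = 0.0, c0
    for _ in range(steps):
        h, c = lstm_step(0.0, h, c, params)
    return c


names = ["w_f", "u_f", "b_f", "w_i", "u_i", "b_i", "w_o", "u_o", "b_o", "w_g", "u_g", "b_g"]
zeros = dict.fromkeys(names, 0.0)
remember = {**zeros, "b_f": 10.0, "b_i": -10.0}  # forget gate open, input gate closed
forget = {**zeros, "b_f": -10.0}                 # forget gate closed
print(carry_cell_state(remember), carry_cell_state(forget))
```

Because the cell state is updated by gated addition rather than repeated squashing, gradients flowing back through an open forget gate shrink far more slowly than in a plain RNN.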
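Q-Learning
Q-Learning is a “model-free” reinforcement learning approach: rather than learning from labeled examples, an agent learns the expected long-term reward, or “Q-value,” of taking each action in each state, guided by the rewards it receives from its environment. DeepMind famously combined Q-learning with a deep neural network in its DQN agent, which learned to play dozens of Atari 2600 games directly from screen pixels. DeepMind’s later AlphaGo system, which defeated Lee Sedol, one of the world’s strongest Go players, in 2016, also relied on reinforcement learning, though combined with tree search rather than Q-learning. Reinforcement learning has since mastered real-time strategy games as well, such as StarCraft II with DeepMind’s AlphaStar.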
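Here is the Q-learning update rule in its simplest, tabular form: a toy corridor where the agent starts at one end and is rewarded only for reaching the other. DQN replaces the table with a neural network, but the update target (reward plus the discounted best future value) is the same idea. The environment, function name, and hyperparameters are hypothetical.

```python
import random


def q_learning(n_states=5, episodes=300, alpha=0.5, gamma=0.9, epsilon=0.2, seed=0):
    """Tabular Q-learning on a toy corridor: start at state 0, reward 1 at the last state.

    Actions: 0 = step left, 1 = step right.
    """
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(n_states)]
    goal = n_states - 1
    for _ in range(episodes):
        state = 0
        while state != goal:
            # Epsilon-greedy: usually exploit the best known action, sometimes explore.
            if rng.random() < epsilon:
                action = rng.randrange(2)
            else:
                action = 0 if q[state][0] > q[state][1] else 1
            next_state = max(0, state - 1) if action == 0 else state + 1
            reward = 1.0 if next_state == goal else 0.0
            future = 0.0 if next_state == goal else max(q[next_state])
            # Q-learning update: move toward reward + discounted best future value.
            q[state][action] += alpha * (reward + gamma * future - q[state][action])
            state = next_state
    return q


q = q_learning()
policy = ["left" if left > right else "right" for left, right in q[:-1]]
print(policy)  # ['right', 'right', 'right', 'right']
```

No model of the environment’s rules is ever built; the agent learns only from the rewards it observes, which is what “model-free” means.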
Word2Vec
Thus far, we haven’t spent much time on text models, so it’s time to begin with Word2Vec. Word2Vec is a family of shallow neural network models that learn a vector (an “embedding”) for each word, such that words appearing in similar contexts end up close together in vector space, as measured by cosine similarity. Word2Vec produces these distributed representations with one of two architectures: continuous bag-of-words (CBOW), which predicts a word from its surrounding context, or continuous skip-gram, which predicts the surrounding context from a word. CBOW is faster, while skip-gram seems to handle infrequent words better. You can learn more about word2vec in the documentation for the industry-standard gensim Python package. If you’re working with biological sequences like DNA, RNA, or even proteins, word2vec also proves useful: by treating short subsequences (k-mers) as “words,” it can handle sequences of biological and molecular data in the same way it handles text.
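Two small pieces of the word2vec workflow are easy to show without any library: cosine similarity, which is how embedding closeness is usually measured, and k-mer splitting, which turns a DNA string into “words.” The embedding values below are hypothetical stand-ins for trained vectors; with gensim you would train a `Word2Vec` model (its `sg` parameter selects skip-gram or CBOW) and query similarities from its `wv` attribute.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def kmers(sequence, k=3):
    """Overlapping length-k substrings, e.g. to treat DNA as a sequence of "words"."""
    return [sequence[i:i + k] for i in range(len(sequence) - k + 1)]


# Hypothetical 3-D embeddings; real word2vec vectors typically have 100-300 dimensions.
embeddings = {
    "king": [0.80, 0.60, 0.10],
    "queen": [0.78, 0.62, 0.15],
    "banana": [0.05, 0.10, 0.95],
}
print(cosine_similarity(embeddings["king"], embeddings["queen"]))   # close to 1
print(cosine_similarity(embeddings["king"], embeddings["banana"]))  # much lower
print(kmers("ATGCGA"))  # ['ATG', 'TGC', 'GCG', 'CGA']
```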
Transformers
After LSTMs and RNNs reigned as the state of the art for natural language processing for several years, a group of researchers at Google Brain in 2017 formulated an architecture built on multi-head attention layers that performed unreasonably well on translation workloads. These “attention units” typically compute a scaled dot product: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, where $Q$, $K$, and $V$ are the query, key, and value matrices and $d_k$ is the dimension of the keys. The main drawback of this architecture is that the cost of attention grows quadratically with input length, so training on large datasets and verifying performance on long inputs is computationally intensive and time-consuming. Modifications to this approach typically focus on reducing that complexity from $O(N^2)$ to $O(N \log N)$ with Reformers, or to $O(N)$ with ETC or BigBird, where $N$ is the input sequence length. Large models like BERT are typically the product of self-supervised pre-training on large Internet corpora, followed by supervised fine-tuning; such a large model can also serve as a “teacher” for training a smaller “student” model. Common tasks that transformers can reliably perform include:
- Question Answering
- Reading Comprehension
- Sentiment Analysis
- Next Sentence Prediction/Synthesis
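The attention formula above can be implemented in a few lines of plain Python. This sketch covers a single head with no masking and no learned projection matrices (a real transformer layer computes Q, K, and V from the input with learned weights and runs several heads in parallel); the vectors are hypothetical. The `scores` matrix has one entry per query-key pair, which is where the quadratic cost in sequence length comes from.

```python
import math


def softmax(row):
    shift = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, one head, no mask."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))  # one score per (query, key) pair
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V), weights


# Hypothetical 2-D vectors for two tokens; the query closely matches the first key.
Q = [[4.0, 0.0]]
K = [[4.0, 0.0], [0.0, 4.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
output, weights = attention(Q, K, V)
print(weights, output)  # nearly all weight on the first value row
```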
The authors of this model class titled their 2017 research paper “Attention Is All You Need,” and the title proved prescient: “transformer” is now a broad model category in its own right, and the architecture has influenced vision systems (Vision Transformers, or ViT for short), CLIP, DALL-E, GPT, BLOOM, and other highly influential models. Next, we’ll jump into a series of specific and canonically influential models; you’ll find the transformer-based models at the end of the list that follows.
What are some commonly used models?
AlexNet
AlexNet was the model that demonstrated that compute power and convolutional neural nets could scale to classify images into as many as 1000 different classes in Stanford’s canonical ImageNet dataset (and the corresponding image classification challenge). The model consisted of a series of convolutional layers, ReLU activations, some “pooling” layers, and fully connected hidden layers, all of which we’ll describe in a later section. It was one of the first widely reproduced models to be trained on graphics processors (GPUs), two NVIDIA GTX 580s, to be specific, and nearly every successor model was also trained on GPUs. AlexNet won the 2012 ImageNet challenge (ILSVRC) and was published at NeurIPS (then called NIPS) that year, and it became the inspiration for many successor state-of-the-art networks like ResNet, RetinaNet, and EfficientDet. Whereas predecessor neural networks such as Yann LeCun’s LeNet could perform fairly reliable classification into 10 classes (mapping to the 10 decimal digits), AlexNet could classify images into any of 1000 different classes, complete with confidence scores.
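Two of the layer types mentioned above, ReLU activation and max pooling, are simple enough to sketch directly. This uses non-overlapping 2x2 windows with stride 2 for clarity; AlexNet itself used overlapping 3x3 pooling windows with stride 2. The feature-map values are hypothetical.

```python
def relu_map(feature_map):
    """Zero out negative activations."""
    return [[max(0.0, v) for v in row] for row in feature_map]


def max_pool_2x2(feature_map):
    """Keep the largest value in each non-overlapping 2x2 window (stride 2)."""
    return [
        [max(feature_map[i][j], feature_map[i][j + 1],
             feature_map[i + 1][j], feature_map[i + 1][j + 1])
         for j in range(0, len(feature_map[0]) - 1, 2)]
        for i in range(0, len(feature_map) - 1, 2)
    ]


# Hypothetical 4x4 feature map, e.g. the output of a convolution.
feature_map = [
    [1.0, -2.0, 3.0, 0.0],
    [-1.0, 5.0, -3.0, 2.0],
    [0.0, 0.0, -1.0, -1.0],
    [4.0, -6.0, 2.0, 8.0],
]
print(max_pool_2x2(relu_map(feature_map)))  # [[5.0, 3.0], [4.0, 8.0]]
```

Pooling halves each spatial dimension here, so later layers see a coarser, cheaper summary that keeps the strongest responses.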
ResNet
Residual Networks, ResNet for short, encompass a series of image classifiers of varying depth. (Depth, here, roughly scales with classification accuracy and also with compute time.) As networks grow deeper than AlexNet, they often fail to converge reliably. This “exploding/vanishing gradient” problem (“vanishing” was explained above, while “exploding” means the values rapidly grow toward the maximum range of the floating-point data type) becomes more challenging as the ML engineer adds more layers or blocks to the model. ResNet addresses it by having blocks of layers learn a residual function whose output is added back to the block’s input through a “skip connection.” Compared to earlier “VGG” nets, this design enables models up to 152 layers deep that are more accurate yet require less computation, and that still have reasonable inference time and memory usage. In 2015, ResNet set the standard by winning 1st place in the ImageNet classification, detection, and localization competitions and the COCO detection and segmentation competitions.
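The skip connection is easy to see in a forward-pass sketch: a residual block outputs ReLU(F(x) + x), so even when the learned layers F contribute almost nothing, the input still flows through. Here a hypothetical “weak” layer with tiny weights (as at the start of training) is stacked 20 deep with and without skip connections. This illustrates only the forward signal; the benefit for gradients during backpropagation is analogous but not shown.

```python
def residual_block(x, transform):
    """Output = ReLU(F(x) + x): the layers only need to learn the residual F."""
    fx = transform(x)
    return [max(0.0, f + xi) for f, xi in zip(fx, x)]  # skip connection, then ReLU


def plain_block(x, transform):
    """The same layers without the skip connection, for comparison."""
    return [max(0.0, f) for f in transform(x)]


def weak_layer(v):
    """Hypothetical layer with tiny weights, as early in training."""
    return [0.01 * a for a in v]


signal = [1.0, 2.0]
plain, residual = signal, signal
for _ in range(20):  # stack 20 blocks deep
    plain = plain_block(plain, weak_layer)
    residual = residual_block(residual, weak_layer)
print(plain)     # shrinks toward zero
print(residual)  # still carries the input signal
```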
Single Shot MultiBox Detector (SSD)
Single Shot MultiBox Detector (SSD) was introduced by Wei Liu and colleagues in late 2015 and published at ECCV 2016. Implemented in Caffe, it provided an efficient way for a neural network to predict object bounding boxes and class scores in a single forward pass. While AlexNet proved the value of image classification to the market, SSD paved the way for increasingly accurate forms of object detection, even at the high frame rates found on a webcam (often 30 or 60 frames per second, or “FPS”), an autonomous vehicle, or a security camera (often lower, perhaps 24 FPS).
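Detectors like SSD rely on intersection over union (IoU) to compare boxes. SSD, for instance, treats a default box as matching a ground-truth box during training when their IoU (also called Jaccard overlap) exceeds 0.5. A minimal sketch, assuming boxes are given as `(x_min, y_min, x_max, y_max)` tuples:

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


print(iou((0, 0, 10, 10), (5, 0, 15, 10)))  # 50 / 150, about 0.333
```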
Faster R-CNN and Mask R-CNN
It’s helpful to have bounding boxes (usually rectangles) to identify objects in an image, but sometimes even greater levels of detail are useful. Happily, it’s possible to train models on datasets that include images and matching per-pixel labels. Faster R-CNN (published in 2015 by researchers at Microsoft Research) predicts bounding boxes efficiently, and its successor, Mask R-CNN (published in 2017 by Facebook AI Research), extends it with a branch that predicts a pixel-wise mask for each detected object in a relatively short amount of time. Matterport, a residential 3D scanning company, published a popular open-source implementation of Mask R-CNN. These models can be particularly useful for robotics applications such as picking objects out of boxes, or re-arranging objects in a scene. Facebook AI Research released its own Mask R-CNN implementation on GitHub as part of “Detectron.” We’ll cover its successor, Detectron2, below.
You Only Look Once (YOLO) and YOLOv3
Along the lines of SSD, mentioned above, Joseph Redmon at the University of Washington built another “single shot” object detector, YOLO. Rather than relying on a mainstream framework, he hand-coded it in his own Darknet framework, written in C and CUDA (a GPU programming language), with the goal of making detection run extremely fast. The architecture lives on in later versions, including Ultralytics’ PyTorch-based YOLOv5, around which that company has built a business deploying models to customers. YOLO is an architecture that has stood the test of time and remains relevant today, pushing the boundaries of high-quality, high-speed object detection.
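Single-shot detectors such as YOLO and SSD typically emit many overlapping candidate boxes for the same object, which are then pruned with non-maximum suppression (NMS). Below is a minimal greedy NMS sketch; the boxes, scores, and 0.5 threshold are hypothetical, and the IoU helper is repeated so the block runs on its own.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: keep the highest-scoring box, drop boxes that overlap it too much, repeat."""
    remaining = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while remaining:
        best = remaining.pop(0)
        keep.append(best)
        remaining = [i for i in remaining if iou(boxes[best], boxes[i]) < iou_threshold]
    return keep


# Hypothetical detections: two near-duplicates of one object, plus a separate object.
boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]
print(non_max_suppression(boxes, scores))  # [0, 2]
```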
Inception v3 (2015), RetinaNet (2017) and EfficientDet (2020)
In the decade since Alex Krizhevsky published AlexNet at the University of Toronto, a new model has topped the annual ImageNet and COCO challenges every few years. Today, high-accuracy, high-speed object detection may seem like a “solved” problem, but there is always room to discover smaller, simpler models that perform better on speed, quality, or some other metric. Some steps forward in the state of the art for object detection have come from “neural architecture search,” which uses a model to select different configurations, sizes, and types of layers. Yet today’s best models mostly build on what those experiments found rather than explicitly searching for better configurations with an independent ML model.