Abstract
The applications of machine learning are undeniably widespread. Many of my past projects have required some form of learning or neural network. My preferred language for experimentation is Java, which has no commonly used machine learning library. I have thus repeatedly implemented artificial neural network frameworks across projects. This project is my solution to that problem: I submit a Java library containing highly generalized machine learning structures to facilitate future projects in a modular, efficient, and simple way.
Neural Network Architecture
The goal of machine learning is to estimate a function by optimizing its parameters with respect to a loss function. The most common class of functions used in machine learning is matrix multiplication: given an input vector, the output is produced by multiplying it by a matrix, and the parameters to optimize are the entries of that matrix. This is called linear regression because the output is a linear function of the inputs.
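As a concrete sketch, such a layer reduces to a matrix-vector product. The class and method names below are illustrative, not the library's actual API:

```java
/** Minimal sketch of a linear layer: output = W * input. */
public class LinearLayer {
    private final double[][] weights; // the parameters to optimize

    public LinearLayer(double[][] weights) {
        this.weights = weights;
    }

    /** Computes the matrix-vector product W * input. */
    public double[] forward(double[] input) {
        double[] output = new double[weights.length];
        for (int row = 0; row < weights.length; row++) {
            double sum = 0.0;
            for (int col = 0; col < input.length; col++) {
                sum += weights[row][col] * input[col];
            }
            output[row] = sum;
        }
        return output;
    }
}
```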
In many cases, the function we would like to estimate is not linear, so the class of linear functions is not sufficient. To resolve this issue, non-linear functions are introduced. Commonly used examples include the rectified linear unit (ReLU), the logistic function, the softplus function, and the swish function. There is also the softmax function, which normalizes a vector into a probability distribution. These non-linear functions are called activation functions. To create a neural network, inputs are passed through multiple types of functions in sequence; standard networks alternate between linear regression layers and activation layers.
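For reference, here is a minimal sketch of these activations as standalone static methods (again illustrative, not the library's API):

```java
/** Minimal sketches of common activation functions. */
public final class Activations {
    /** ReLU: max(0, x). */
    public static double relu(double x) {
        return Math.max(0.0, x);
    }

    /** Logistic (sigmoid): 1 / (1 + e^-x). */
    public static double logistic(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Softplus: ln(1 + e^x), a smooth approximation of ReLU. */
    public static double softplus(double x) {
        return Math.log1p(Math.exp(x));
    }

    /** Swish: x * sigmoid(x). */
    public static double swish(double x) {
        return x * logistic(x);
    }

    /** Softmax: normalizes a vector into a probability distribution. */
    public static double[] softmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) max = Math.max(max, v); // subtract the max for numerical stability
        double sum = 0.0;
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }
}
```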
Learning and Optimization
We start with the classification problem: we would like our network to assign a class to each input item, so the output of the network is a probability distribution over all possible classes. Optimization also requires a loss function. The most common loss function for networks that use the softmax activation function is the cross-entropy loss, the expected surprise (negative log-probability) of the correct output class across all items passed through the network.
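A minimal sketch of this loss, assuming the network output is already a softmax probability vector:

```java
/** Sketch of cross-entropy loss over softmax output probabilities. */
public final class CrossEntropy {
    /** Loss for one item: -ln(p[correctClass]), the surprise of the correct class. */
    public static double loss(double[] probabilities, int correctClass) {
        return -Math.log(probabilities[correctClass]);
    }

    /** Average (expected) loss over a batch of items. */
    public static double averageLoss(double[][] batchProbabilities, int[] correctClasses) {
        double total = 0.0;
        for (int i = 0; i < batchProbabilities.length; i++) {
            total += loss(batchProbabilities[i], correctClasses[i]);
        }
        return total / batchProbabilities.length;
    }
}
```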
This loss function is easy to differentiate. Starting from the derivative of the loss with respect to the outputs of the neural network, backpropagation finds its derivatives with respect to the network parameters. A simple gradient descent algorithm can then be applied to find a set of parameters with low loss.
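The basic update step is a sketch-worthy one-liner: each parameter moves a small step against its gradient, where the learning rate is a tunable hyperparameter:

```java
/** Sketch of one gradient descent step on a flat parameter array. */
public final class GradientDescent {
    /** Applies parameters[i] -= learningRate * gradients[i] for every parameter. */
    public static void step(double[] parameters, double[] gradients, double learningRate) {
        for (int i = 0; i < parameters.length; i++) {
            parameters[i] -= learningRate * gradients[i];
        }
    }
}
```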
Weight Initialization
Various techniques exist to improve the results of network optimization. A good way to improve the convergence rate is to regulate variance across the network layers. The derivative of the loss function with respect to each network layer directly determines how that layer's parameters are updated. If these derivatives grow or shrink too much as they backpropagate, the changes to the network parameters may be too extreme.
The linear layers are the biggest cause of this issue. If the parameters of a linear layer are initialized with zero mean and unit variance, backpropagation scales the derivatives by a factor proportional to the layer's number of inputs. To fix this, we make the variance of the weight initialization inversely proportional to the number of inputs.
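A minimal sketch of this scheme, which is essentially Xavier/Glorot-style initialization (the use of a Gaussian distribution here is an assumption, not necessarily the library's choice):

```java
import java.util.Random;

/** Sketch: weight initialization with variance 1 / numInputs. */
public final class WeightInit {
    public static double[][] gaussian(int numOutputs, int numInputs, Random random) {
        double standardDeviation = Math.sqrt(1.0 / numInputs);
        double[][] weights = new double[numOutputs][numInputs];
        for (int row = 0; row < numOutputs; row++) {
            for (int col = 0; col < numInputs; col++) {
                // Zero mean, variance 1 / numInputs.
                weights[row][col] = random.nextGaussian() * standardDeviation;
            }
        }
        return weights;
    }
}
```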
Optimization Techniques
An update step of stochastic gradient descent uses only the derivative evaluated at the current iteration. Variations of stochastic gradient descent attempt to improve the convergence rate and avoid local minima by incorporating information from previous iterations.
- Momentum models the motion of a spherical object rolling down the loss landscape: a fraction of past update steps is reapplied at each iteration.
- Adaptive gradient (AdaGrad) rescales each component of the gradient update to favor step directions that have not been explored often. RMSProp improves on this idea by replacing AdaGrad's full history of squared gradients with an exponential moving average.
- Adam merges the ideas of momentum and RMSProp, and it is the most common choice for general machine learning applications; a sketch follows this list.
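The following is a minimal sketch of the Adam update in its standard formulation. The hyperparameter values are the commonly cited defaults, not values taken from this library:

```java
/** Minimal sketch of the Adam optimizer for a flat parameter array. */
public class AdamOptimizer {
    private final double learningRate, beta1, beta2, epsilon;
    private final double[] firstMoment, secondMoment;
    private int step = 0;

    public AdamOptimizer(int numParameters, double learningRate) {
        this.learningRate = learningRate;
        this.beta1 = 0.9;    // momentum decay (common default)
        this.beta2 = 0.999;  // squared-gradient decay (common default)
        this.epsilon = 1e-8; // avoids division by zero
        this.firstMoment = new double[numParameters];
        this.secondMoment = new double[numParameters];
    }

    /** Applies one Adam update to the parameters in place. */
    public void update(double[] parameters, double[] gradients) {
        step++;
        for (int i = 0; i < parameters.length; i++) {
            // Exponential moving averages of the gradient and squared gradient.
            firstMoment[i] = beta1 * firstMoment[i] + (1 - beta1) * gradients[i];
            secondMoment[i] = beta2 * secondMoment[i] + (1 - beta2) * gradients[i] * gradients[i];
            // Bias correction for the zero-initialized moving averages.
            double correctedFirst = firstMoment[i] / (1 - Math.pow(beta1, step));
            double correctedSecond = secondMoment[i] / (1 - Math.pow(beta2, step));
            parameters[i] -= learningRate * correctedFirst / (Math.sqrt(correctedSecond) + epsilon);
        }
    }
}
```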
In my experimentation, all of these algorithms performed very similarly. However, I did not spend much time tuning hyperparameters, so I will not draw any conclusions about the effectiveness of these variations.
Runtime
Machine learning applications use GPUs to speed up processing times. My computer does not have a GPU, but there are still ways to improve the runtime of the library functions.
- When writing code, sacrificing efficiency for readability may be beneficial; adding intermediate variables or extra loops can make debugging easier. Once correctness has been confirmed, however, these choices leave plenty of room for optimization.
- Developers must consider the tradeoff between generality and efficiency when designing a library. This library handles the general case of multivariable functions, but a common use case applies many scalar functions in parallel. Speeding up the common case has a huge impact on the overall runtime, so writing extra code to handle this specific case is worthwhile.
- GPUs are used because they can run many tasks in parallel. CPUs can also exploit parallelism through multithreading, although at a much smaller scale. The speedup from adding multithreading easily outperforms all other improvements made to the program; a sketch follows this list.
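As an illustration of the kind of change involved (not the library's actual implementation), a matrix-vector product can be split across CPU cores with Java's parallel streams:

```java
import java.util.stream.IntStream;

/** Sketch: multithreaded matrix-vector product using parallel streams. */
public final class ParallelMath {
    /** Computes W * input, distributing rows across the available CPU cores. */
    public static double[] multiply(double[][] weights, double[] input) {
        double[] output = new double[weights.length];
        IntStream.range(0, weights.length).parallel().forEach(row -> {
            double sum = 0.0;
            for (int col = 0; col < input.length; col++) {
                sum += weights[row][col] * input[col];
            }
            output[row] = sum; // each thread writes a distinct row, so no locking is needed
        });
        return output;
    }
}
```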
Applying these changes led to a 3x increase in speed. Although small compared to what a GPU would offer, the improvement is definitely welcome. I hope to explore GPU acceleration in the future.
Normalization
Normalizing data is a common way to improve the performance of a network. Up to this point, I have manually rescaled the input data to zero mean and unit variance; it would be better for the network to perform this normalization itself. Normalization layers can also be inserted in the middle of a large network to stabilize the mean and variance of intermediate layer outputs.
Batch normalization normalizes a value using the mean and variance computed across a batch of input data. During training, mini-batch gradient descent is used, so every pass through a batch normalization layer can compute a non-trivial mean and variance. At test time, however, items may be passed through one at a time. To resolve this issue, batch normalization tracks an exponential moving average of the means and variances encountered during training and uses these running values at test time.
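A minimal sketch of this behavior for a single feature follows. The momentum value is an illustrative choice, and the learnable scale and shift parameters of full batch normalization are omitted for brevity:

```java
/** Sketch of batch normalization for a single feature. */
public class BatchNorm {
    private static final double EPSILON = 1e-5;  // avoids division by zero
    private static final double MOMENTUM = 0.9;  // decay rate for the running averages
    private double runningMean = 0.0;
    private double runningVariance = 1.0;

    /** Training mode: normalize with batch statistics and update the running averages. */
    public double[] forwardTrain(double[] batch) {
        double mean = 0.0;
        for (double x : batch) mean += x;
        mean /= batch.length;

        double variance = 0.0;
        for (double x : batch) variance += (x - mean) * (x - mean);
        variance /= batch.length;

        // Exponential moving averages used later at test time.
        runningMean = MOMENTUM * runningMean + (1 - MOMENTUM) * mean;
        runningVariance = MOMENTUM * runningVariance + (1 - MOMENTUM) * variance;

        double[] out = new double[batch.length];
        for (int i = 0; i < batch.length; i++) {
            out[i] = (batch[i] - mean) / Math.sqrt(variance + EPSILON);
        }
        return out;
    }

    /** Test mode: normalize a single value with the running statistics. */
    public double forwardTest(double x) {
        return (x - runningMean) / Math.sqrt(runningVariance + EPSILON);
    }
}
```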