Paper Link : https://arxiv.org/abs/1506.02626
Github : x
Introduction
현재 MIT 조교수로 재직중인 Song Han이 Stanford에서 박사과정 중에 쓴 network pruning paper입니다. 그가 제시한 방법은 다음과 같습니다.
Strategy
1. Pre-training network
2. Prune weight
3. Fine-tune
4. Iterate 2-3
우선 어떤 weight이 중요한 지 알기 위해 neural network를 training 시킵니다. 그리고 중요하지 않은 weight을 일정 threshold가 넘으면 남겨두고, 넘지 않으면 pruning 시켜줍니다. 그리고 남은 network에 대해서 Fine-tuning을 해주고, 다시 pruning을 반복합니다. 즉 weight들 중에 중요한 옥석들만 골라내겠다는 것입니다. 논문에서는 정확히 어떤 값을 threshold로 사용했는 지 밝히지 않고 있으나, median, mean 등이 쓰일 수 있을 것 같습니다.
Dropout Ratio Adjustment
Weight은 dropout을 하면 학습중에만 확률적으로 0이 되고, 이후 inference시에는 다시 되돌아 옵니다. 그런데 pruning을 하게 되면 network의 weight이 돌아오지 않게 되고 이는 sparse network를 가져옵니다. 그래서 pruning 이전과 같은 dropout rate을 적용하게 되면 network를 너무 sparse하게 만들 수 있어 dropout ratio를 다음과 같이 조절해줍니다.
C_i : i번째 layer의 non-zero weight 갯수
N_i : i번째 layer의 neuron 갯수
C_io : original network의 i번째 layer의 non-zero weight 갯수
C_ir : retraining network의 i번째 layer의 non-zero weight 갯수
D_o : original network의 dropout rate
D_r : retraining network의 dropout rate
그래서 retraining network의 dropout rate을 다음과 같이 정의하게 되면 retraining을 진행할 때마다 값이 점점 작아짐을 알 수 있습니다.
Experiment
LeNet, AlexNet, VGG와 같은 모델들의 경우 parameter는 약 10배 줄였음에도 불구하고 error가 조금은 줄어든 것으로 나타났습니다.
본 논문에는 이미지와 weights들을 visualization 함으로써 왜 pruning이 잘 작동하는가에 대해 다음과 같이 이야기를 했습니다. 예를 들어서, MNIST data는 글자나 숫자가 대부분 이미지의 중앙에 위치해있습니다. 그 주위에 다른 곳들은 아무런 정보를 갖고 있지 않습니다. 즉, 우리는 이미지의 중앙만 보고 그 이외의 자리는 보지 않더라도 이 글자가 어떤 글자인지 분류할 수 있습니다. Neural network는 이와 같이 중요하지 않은 부분들의 weights은 0에 수렴하게끔 학습이 됩니다. 그리고 우리는 이를 threshold를 기반으로 하여 중요하지 않은 weight들을 pruning함으로써 accuracy drop없이 model size를 줄일 수 있는 것입니다.