Paper Link : https://arxiv.org/abs/1802.05668
Github : https://github.com/antspy/quantized_distillation
스터디에서 발표했다. Quantized Knowledge Distillation 관한 내용이었는데, 관련 배경지식이 많이 부족해서 힘들었다. KD는 심플하게 Teacher Network의 soft label로 Student Network를 training 시키는 것으로 알고 있다. 이 과정이 Knowledge Transfer이고 이때 생기는 loss를 Distillation Loss라고 한다. 그런데 Quantization을 공부한다면 다음과 같은 생각을 해볼 수 있을 것이다. 적정 Criteria에 따라서 student network를 만들었는데 이 녀석을 애초에 quantized된 녀석으로 사용해서 training 시켜볼 수는 없을까? 라는 호기심. 이런 호기심에 의해 나오게된 논문이 ICLR2018에 소개된 "Model Compression via Distillation and Quantization"이라는 논문이다.
일단 student의 weight을 scaling할 때는 linear scaling을 사용하였고, scaling할 때 low-precision 문제로 bucketing이라는 테크닉을 사용했다. 그리고 난 뒤에는 uniform quantization을 이용했다. (사실 이 부분이 잘 안와닿았다.) bucketing하는 부분까지는 이해가 잘됐는데 이후에 uniform quantization을 왜 논문의 식과 같이 했는지 당위가 나에게는 보이지 않아서 답답했다.
이런식으로 student network를 quantization을 하고 나면 2가지 의문이 생긴다. 첫 번째는 knowledge transfer를 어떻게 시킬 것이냐? 이에 대한 대답은 간단하다. 2015년에 Hinton이 사용한 방법을 사용한다. 두 번째는 quantized neural network라는 상황에서 distillation loss를 어떻게 적용해서 사용할거냐. 이에 대한 대답은 "Projected Gradient Descent"이다. PGD는 이 논문을 통해서 처음 알게된 최적화 방법인데 최적화를 시켜주고 싶을 때 내가 생각하는 적당한 feasible set이 있다면 거기에 대해 제약을 걸어준 뒤 GD를 하는 방법이다. 이때 새로운 point를 update하는 방식이 feasible set에 대해 그 point를 projection 시키는 것과 같아서 PGD라는 이름이 붙었다.
이런 걸 통해서 Quantized Distillation이란 걸 하는데 과정은 다음과 같다.
1. Weight을 Quantization function을 통해 quantize해준다.
2. Forward Pass를 통해 distillation loss를 계산해준다.
3. Backward Pass를 통해 distillation loss에 대한 gradient loss를 계산해준다.
4. Parameter를 update할 때는 SGD를 이용하여 full precision으로 original weight을 update 해준다.
5. update된 weight을 다시 quantization function을 통해 quantize시켜준다.
Quantized Distillation을 어떻게 할 것인지에 대해서 생각을 해봤으면 이제 어떻게 해야 잘 Quantization을 할 수 있을까, student가 teacher를 잘 따라하게끔 quantized weight을 만들어 줄 수 있을까에 대한 고민이 생긴다. 이때 본 논문에서는 "Differentiable Quantization"이라는 것을 도입해서 해결한다. 결국 우리가 Quantization을 통해서 하고 싶은 건 어떤 quantization function을 통해서 만들어진 quantized vector가 accuracy loss를 최소화하게끔 만들고 싶은 건데 이걸 우리는 적당한 backpropagation을 통해서 update를 해주고 싶다. 그런데 backprop를 하려면 gradient를 계산해야 한다. 이때 문제가 생긴다. 각각의 quantization point를 결정하는 값들은 모두 불연속적이다. 그래서 gradient가 0이 된다. 그럼 우리가 backprop을 하는 의미가 없어진다. 이 방법을 해결하기 위해서 BNN 계열에서 자주 쓰이는 "Straight-Through Estimator"라는 테크닉을 사용한다. 이 트릭은 Hinton의 논문에서 맨처음 사용된 것이라고 하는데, 그리고 본 논문에서는 BinaryConnect라는 곳에서 비슷하게 썻다구 하는데 여튼 이걸 쓰면 gradient를 적당히 비틀어서 원래 gradient가 아닌 다른 gradient로 계산을 할 수 있게끔 해준다. 본 논문에서는 quantization function을 p_j에 대해서 gradient를 계산했을 때 v_i가 p_j에 대해 quantization됐으면 scaling factor를 뱉어내게끔 정의해뒀다. (이때 bucketing이 가정됐다면 bucketing size를 뱉어낼 것이다.)
이걸 갖구 Differentiable Quantization이란 걸 하는데 과정은 다음과 같다.
1. Weight을 Quantization function을 통해 quantize해준다.
2. Forward pass를 통해 loss를 계산해준다.
3. Backward pass를 통해 loss의 gradient를 계산해준다.
4. gradient를 계산할 때 STE를 이용해 계산한다.
5. SGD와 같은 최적화를 통해 quantization point를 update해준다.
실험은 본 논문에서 CNN, RNN을 통해서 진행했는데 나는 CNN에 관한 것만 리뷰를 했다. 여기서는 training이후에 다른 부가적인 operation 없이 uniform quantization을 진행하는 PM quantization과 논문에서 소개한 방법들을 비교해뒀다. 재밌었던 것 Parameter가 많지 않은 network에 대해서는 PM quantization이나 논문에서 소개된 방법이나 그렇게 크게 차이가 나보이지 않았다. 물론 10-20%는 엄청난 차이지만. 그런데 ResNet에 대해서는 PM quantization이 2-bit quantization을 했을 때 상당한 accuracy drop을 보였다. accuracy가 약 10%였다. 하지만 논문에서 소개한 방법은 80%를 넘는 accuracy를 보였다. 이는 아마 PM quantization은 training 이후에 quantization을 하는 것이라 그런 것일 거구, 논문에 소개된 방법은 training 과정 중에 quantization을 같이 하기 때문에 그런 것일 거다. 이 점이 재밌었고. 한가지로 똑같은 실험에서 논문에서 소개한 방법으로 2-bit quantization을 한 것이나 4-bit quantization을 한 것이나 accuracy가 거의 차이가 나지 않았다는 것이 재밌었다.
스터디에서 첫 발표였구 나에게는 여러모로 뜻깊은 경험이었다. 다음 발표 때는 더 잘 준비해야지..