0964 456 787 [email protected]

Vấn đề chúng ta thường gặp phải khi triển khai các model để giải quyết các vấn đề phức tạp, đó là làm sao sử dụng model với quá nhiều trọng số này để triển khai trên mobile cho người dùng hoặc 1 ứng dụng cần tính real-time. Rõ ràng, cũng có 1 phương pháp cho vấn đề này mà chúng ta có thể nghĩ tới ngay đó là triển khai model này trên 1 cloud với khả năng tính toán tốt rồi gọi nó khi nào chúng ta cần sử dụng dịch vụ. Tuy nhiên cách này yêu cầu người sử dụng cần phải kết nối internet, vì vậy nó sẽ hạn chế trong 1 số ứng dụng. Vậy điều chúng ta cần là phải triển khai được model trên các thiết bị mobile.

Mô hình Knowledge Distillation

Vậy vấn đề ở đây là gì? Chúng ta có thể huấn luyện một mạng NN mà nó có thể chạy trên một thiết bị hạn chế về tài nguyên tính toán như mobile. Nhưng có 1 vấn đề trong cách tiếp cận này. Các model nhỏ không thể trích xuất được nhiều các thông tin phức tạp. Trong trường hợp này chúng ta có thể sử dụng 2 phương pháp sau:

  • Knowledge Distillation
  • Model Compression

Ở bài viết này tôi sẻ giới thiệu về Knowledge Distillation. Model Compression sẽ được đề cập trong bài viết khác.

Vậy Knowledge Distillation là một cách đơn giản để cải thiện hiệu suất của Deep Learning Model trên thiết bị mobile. Trong quá trình này, chúng ta huấn luyện dữ liệu trên model lớn và phức tạp hoặc sử dụng ensemble model nhờ vào đó nó có thể trích xuất được các thông tin quan trọng có trong data và vì vậy cho kết quả dự đoán tốt hơn. Sau đó chúng ta huẩn luyện với model gọn nhẹ hơn với sự giúp đỡ của model trước. Khi đó model gọn nhẹ này có thể cho ra kết quả dự đoán tương đương.

GoogLeNet

Một ví dụ như sau, như cấu trúc mạng GoogLeNet như trên, nó rất phức tạp (có nghĩa là vừa deep vừa complex), nhờ được xếp với nhiều tầng như vậy nên nó có thể trích xuất được các thông tin phức tạp. Tuy nhiên model này thì quá cồng kềnh và không thể triển khai trên các thiết bị có cấu hình thấp. Vì vậy đây là lý do tại sao chúng ta cần phản chuyển những thông tin đã học được (knowledge learned) ở model phức tạp này tới 1 model đơn giản hơn. Ở phần này chúng ta xem model phức tạp là Teacher Nework và model mới đơn giản hơn là Student Network.

Model phức tạp học để phân biệt số lượng lớn các lớp. Một đối tượng huấn luyện thông thường sẽ làm tối đã hàm xác suất trung bình của kết quả chính xác và tính toán xác suất cho tất cả các lớp, với 1 số lớp sẽ có kết quả xác suất rất thấp. Xác suất tương đối của kết quả sai nói cho ta nhiều về độ phức tạp của model tạo ra. Ví dụ đối với một hình của chiếc xe hơi, ví dụ, xác suất để nó nhầm thành xe tải là thấp, nhưng xác suất này vẫn sẽ lớn hơn so với việc nó nhầm thành là một con chó.

Teach & Student

Bạn có thể ‘cô đọng’ (distill) một kiến trúc mạng lớn và phức tạp trong một kiến trúc mạng đơn giản hơn và kiến trúc mạng mới này có thể hoạt động tương đương với kiến trúc mạng trước kia.

Mô hình Teacher & Student

Tuy nhiên, sự khác nhau là mô hình student sẽ được huấn luyện để băt chước kết quả của kiến trúc mạng teacher thay vì được huẩn luyện trực tiếp từ tập dữ liệu ban đầu.

Cách thức hoạt động

Cách thức hoạt động Teacher - Student

Việc chuyển khả năng khái quát hóa của model phức tạp sang model đơn giản có thể được thực hiện bằng việc sử dụng các lớp xác suất sau khi qua hàm softmax của model phức tạp để huấn luyện cho model đơn giản. Đối với quá trình này, chúng ta sử dụng chung tập huấn luyện cho cả 2 model. Nguyên nhân chúng ta sử dụng soft label ở bước này vì soft label sẽ chứa nhiều thông tin hơn là hard label .

Phần lớn các thông tin đã được học nằng trong các tỷ lệ xác suất nhỏ của soft label. Đây là thông tin có giá trị xác định cấu trúc giống nhau của dữ liệu (Ví dụ như khi chúng ta nói đây là số 1 nhìn giống sô 7, ..) nhưng nó sẽ có giá trị nhỏ khi tính cross-entropy cost function bởi vì xác suất rất nhỏ.

Distillation

Để cô động các tri thức đã học được chúng ta sử dụng Logits (đầu vào của hàm softmax cuối cùng). Logits có thể được sử dụng cho việc học của model đơn giản và nó có thể thực hiện bằng cách tối thiểu sự sai khác bình thương giữa các logits được tạo ra bởi model phức tạp và model đơn giản.

\(P_t(a) = \frac{exp(p_t(a)/ \tau)}{ \sum_{i=1}^n exp(p_t(i)/\tau)}\)
Softmax with Temperature

Với tham số Temperatures cao (\(\tau -> \tex{inf}\)), tất cả các label sẽ có xác suất xấp xỉ và tại giá trị Temperatures thấp ( \(\tau -> 0\) ) xác suát của label với mong muốn cao nhất gần bằng 1.

Trong distillation, chúng ta sẽ tăng tham số Temperatures để cho kết quả của lớp final softmax của model phức tạp đạt được soft label phù hợp. Sau đó chúng ta sử dụng giá trị Temperature cao trong khi huấn luyện model đơn giản để phù hợp với soft label này.

Hàm mục tiêu

Hàm mục tiêu đầu tiên là cross-entropy với soft label và hàm này được tính toán sử dụng chung tham số Temperature giá trị lớn trong hàm softmax của distilled model như đã sử dụng để tạo ra soft label từ model phức tạp.

Hàm mục tiêu thứ 2 là cross-entropy với correct labels và nó được tính toán chính xác giống với logits trong hàm softmax của distilled model nhưng với giá trị temperature 1.

Huấn luyện ensembles model

Huấn luyện một ensembles model là một cách đơn giản để tận dụng lợi thế của tính toán song song. Nhưng thời gian để dự đoán qua model này rất lớn. Nhưng ở đây chúng ta có thể giải quyết với giải pháp mà bài viết này đang trình bày.

Ensemble model to Small model

(To be continue)

Mời các bạn đọc thêm các bài viết:

Tham khảo các bài viết nước ngoài: Tap into the dark knowledge using neural nets — Knowledge distillation