Tự Học Data Science · 23/10/2023 0

Chương 2 – Bài 8 – Sắp xếp mảng

Tự học Data Science tại Blog của Lưu

Sắp xếp Mảng trong NumPy: Hiệu quả và Ứng dụng trong AI

Chào mừng các bạn đến với bài viết mới trên blog của tôi! Trong bài viết này, chúng ta sẽ khám phá các thuật toán sắp xếp mảng trong NumPy – một thư viện quan trọng trong Python, đặc biệt hữu ích trong các ứng dụng AI và khoa học dữ liệu. Từ các phương pháp sắp xếp cơ bản như np.sort, np.argsort đến các kỹ thuật phân vùng như np.partition, chúng ta sẽ đi qua từng khái niệm với ví dụ thực tế, bao gồm cả ứng dụng tìm k láng giềng gần nhất (k-Nearest Neighbors). Hãy cùng tìm hiểu cách NumPy tối ưu hóa việc sắp xếp dữ liệu nhé!


Giới thiệu về Sắp xếp Mảng

Cho đến nay, chúng ta đã tập trung vào các công cụ để truy cập và thao tác dữ liệu mảng với NumPy. Phần này sẽ đề cập đến các thuật toán liên quan đến việc sắp xếp giá trị trong mảng NumPy. Đây là một chủ đề quen thuộc trong các khóa học nhập môn khoa học máy tính: nếu bạn từng học qua, chắc hẳn bạn đã nghe đến insertion sorts, selection sorts, merge sorts, quick sorts, bubble sorts, và nhiều thuật toán khác. Tất cả đều nhằm mục đích thực hiện một nhiệm vụ chung: sắp xếp các giá trị trong danh sách hoặc mảng.

Python cung cấp một số hàm và phương thức tích hợp để sắp xếp danh sách và các đối tượng khả lặp khác. Hàm sorted nhận một danh sách và trả về phiên bản đã sắp xếp của nó:

L = [3, 1, 4, 1, 5, 9, 2, 6]
sorted(L)  # trả về một bản sao đã sắp xếp
# Kết quả: [1, 1, 2, 3, 4, 5, 6, 9]

Ngược lại, phương thức sort của danh sách sẽ sắp xếp danh sách tại chỗ:

L.sort()  # thực hiện tại chỗ và trả về None
print(L)
# Kết quả: [1, 1, 2, 3, 4, 5, 6, 9]

Các phương thức sắp xếp của Python rất linh hoạt, có thể xử lý bất kỳ đối tượng khả lặp nào. Ví dụ, sắp xếp một chuỗi:

sorted('python')
# Kết quả: ['h', 'n', 'o', 'p', 't', 'y']

Tuy nhiên, như đã thảo luận trước đây, tính động của các giá trị trong Python khiến các phương thức này kém hiệu quả hơn so với các routine được thiết kế đặc biệt cho mảng số đồng nhất. Đây là lúc các công cụ sắp xếp của NumPy phát huy tác dụng.


Sắp xếp Nhanh trong NumPy: np.sortnp.argsort

Hàm np.sort tương tự như hàm sorted của Python, trả về một bản sao đã sắp xếp của mảng một cách hiệu quả:

import numpy as np

x = np.array([2, 1, 4, 3, 5])
np.sort(x)
# Kết quả: array([1, 2, 3, 4, 5])

Tương tự phương thức sort của danh sách Python, bạn cũng có thể sắp xếp mảng tại chỗ bằng phương thức sort của mảng:

x.sort()
print(x)
# Kết quả: [1 2 3 4 5]

Một hàm liên quan là argsort, thay vì trả về các phần tử đã sắp xếp, nó trả về chỉ số của các phần tử đã sắp xếp:

x = np.array([2, 1, 4, 3, 5])
i = np.argsort(x)
print(i)
# Kết quả: [1 0 3 2 4]

Phần tử đầu tiên trong kết quả là chỉ số của phần tử nhỏ nhất, phần tử thứ hai là chỉ số của phần tử nhỏ thứ hai, v.v. Các chỉ số này có thể được sử dụng (qua fancy indexing) để tạo mảng đã sắp xếp nếu cần:

x[i]
# Kết quả: array([1, 2, 3, 4, 5])

Chúng ta sẽ thấy ứng dụng của argsort sau trong bài này.

Sắp xếp Theo Hàng hoặc Cột

Một tính năng hữu ích của các thuật toán sắp xếp trong NumPy là khả năng sắp xếp theo hàng hoặc cột cụ thể của mảng đa chiều bằng đối số axis. Ví dụ:

rng = np.random.default_rng(seed=42)
X = rng.integers(0, 10, (4, 6))
print(X)
# Kết quả:
# [[0 7 6 4 4 8]
#  [0 6 2 0 5 9]
#  [7 7 7 7 5 1]
#  [8 4 5 3 1 9]]

# Sắp xếp từng cột của X
np.sort(X, axis=0)
# Kết quả:
# array([[0, 4, 2, 0, 1, 1],
#        [0, 6, 5, 3, 4, 8],
#        [7, 7, 6, 4, 5, 9],
#        [8, 7, 7, 7, 5, 9]])

# Sắp xếp từng hàng của X
np.sort(X, axis=1)
# Kết quả:
# array([[0, 4, 4, 6, 7, 8],
#        [0, 0, 2, 5, 6, 9],
#        [1, 5, 7, 7, 7, 7],
#        [1, 3, 4, 5, 8, 9]])

Lưu ý rằng cách này coi mỗi hàng hoặc cột như một mảng độc lập, và mọi mối quan hệ giữa các giá trị hàng hoặc cột sẽ bị mất!


Sắp xếp Một Phần: Phân vùng với np.partition

Đôi khi chúng ta không cần sắp xếp toàn bộ mảng mà chỉ muốn tìm k giá trị nhỏ nhất. NumPy cung cấp hàm np.partition cho việc này. Hàm np.partition nhận một mảng và số K, trả về một mảng mới với K giá trị nhỏ nhất nằm bên trái điểm phân vùng và các giá trị còn lại bên phải:

x = np.array([7, 2, 3, 1, 6, 5, 4])
np.partition(x, 3)
# Kết quả: array([2, 1, 3, 4, 6, 5, 7])

Lưu ý rằng ba giá trị đầu tiên trong mảng kết quả là ba giá trị nhỏ nhất, còn các vị trí còn lại chứa các giá trị khác với thứ tự bất kỳ.

Tương tự như sắp xếp, chúng ta có thể phân vùng theo trục bất kỳ của mảng đa chiều:

np.partition(X, 2, axis=1)
# Kết quả:
# array([[0, 4, 4, 7, 6, 8],
#        [0, 0, 2, 6, 5, 9],
#        [1, 5, 7, 7, 7, 7],
#        [1, 3, 4, 5, 8, 9]])

Kết quả là một mảng mà hai vị trí đầu tiên trong mỗi hàng chứa các giá trị nhỏ nhất của hàng đó, với các giá trị còn lại điền vào các vị trí còn lại.

Tương tự np.argsort, còn có hàm np.argpartition trả về chỉ số của phân vùng. Chúng ta sẽ thấy cả hai hàm này trong ví dụ tiếp theo.


Ví dụ: Tìm k Láng Giềng Gần Nhất (k-Nearest Neighbors)

Hãy xem cách sử dụng argsort trên nhiều trục để tìm các láng giềng gần nhất của mỗi điểm trong một tập hợp. Chúng ta bắt đầu bằng cách tạo một tập hợp ngẫu nhiên gồm 10 điểm trên mặt phẳng hai chiều, được sắp xếp trong một mảng $10\times 2$:

X = rng.random((10, 2))

Để hình dung các điểm này, hãy tạo một biểu đồ phân tán nhanh:

%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-whitegrid')
plt.scatter(X[:, 0], X[:, 1], s=100);

Tính Khoảng Cách Giữa Các Đôi Điểm

Bây giờ, chúng ta sẽ tính khoảng cách bình phương giữa mỗi cặp điểm. Khoảng cách bình phương giữa hai điểm là tổng bình phương chênh lệch trên mỗi chiều. Sử dụng các routine broadcastingaggregation hiệu quả của NumPy, ta có thể tính ma trận khoảng cách bình phương trong một dòng mã:

dist_sq = np.sum((X[:, np.newaxis] - X[np.newaxis, :]) ** 2, axis=-1)

Nếu bạn chưa quen với quy tắc broadcasting của NumPy, đoạn mã này có thể hơi khó hiểu. Hãy phân tích từng bước:

# Tính chênh lệch tọa độ giữa mỗi cặp điểm
differences = X[:, np.newaxis] - X[np.newaxis, :]
differences.shape
# Kết quả: (10, 10, 2)

# Bình phương các chênh lệch tọa độ
sq_differences = differences ** 2
sq_differences.shape
# Kết quả: (10, 10, 2)

# Tổng các chênh lệch tọa độ để được khoảng cách bình phương
dist_sq = sq_differences.sum(-1)
dist_sq.shape
# Kết quả: (10, 10)

Để kiểm tra logic, đường chéo của ma trận này (khoảng cách giữa mỗi điểm với chính nó) phải bằng 0:

dist_sq.diagonal()
# Kết quả: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Sắp xếp để Tìm Láng Giềng Gần Nhất

Với ma trận khoảng cách bình phương, ta dùng np.argsort để sắp xếp theo từng hàng. Các cột bên trái sẽ cho chỉ số của các láng giềng gần nhất:

nearest = np.argsort(dist_sq, axis=1)
print(nearest)
# Kết quả:
# [[0 9 3 5 4 8 1 6 2 7]
#  [1 7 2 6 4 8 3 0 9 5]
#  [2 7 1 6 4 3 8 0 9 5]
#  [3 0 4 5 9 6 1 2 8 7]
#  [4 6 3 1 2 7 0 5 9 8]
#  [5 9 3 0 4 6 8 1 2 7]
#  [6 4 2 1 7 3 0 5 9 8]
#  [7 2 1 6 4 3 8 0 9 5]
#  [8 0 1 9 3 4 7 2 6 5]
#  [9 0 5 3 4 8 6 1 2 7]]

Cột đầu tiên là các số từ 0 đến 9, vì mỗi điểm là láng giềng gần nhất của chính nó.

Nếu chỉ cần k láng giềng gần nhất, ta không cần sắp xếp toàn bộ mà chỉ cần phân vùng sao cho k + 1 khoảng cách nhỏ nhất nằm đầu tiên:

K = 2
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)

Trực Quan Hóa Mạng Láng Giềng

Để hình dung, hãy vẽ các điểm cùng với các đường nối mỗi điểm đến hai láng giềng gần nhất:

plt.scatter(X[:, 0], X[:, 1], s=100)

# Vẽ đường từ mỗi điểm đến hai láng giềng gần nhất
K = 2
for i in range(X.shape[0]):
    for j in nearest_partition[i, :K+1]:
        plt.plot(*zip(X[j], X[i]), color='black')

Mỗi điểm trong biểu đồ có đường nối đến hai láng giềng gần nhất. Bạn có thể nhận thấy một số điểm có hơn hai đường xuất phát: điều này xảy ra vì nếu điểm A là một trong hai láng giềng gần nhất của điểm B, không có nghĩa là điểm B cũng là láng giềng gần nhất của điểm A.


Kết luận

Cách tiếp cận sử dụng broadcasting và sắp xếp theo hàng tuy có vẻ phức tạp hơn so với viết vòng lặp, nhưng lại rất hiệu quả trong Python. Thay vì lặp thủ công qua dữ liệu và sắp xếp từng tập láng giềng, cách vector hóa này nhanh hơn nhiều. Điều tuyệt vời là mã này không phụ thuộc vào kích thước dữ liệu: bạn có thể tính láng giềng cho 100 hay 1,000,000 điểm mà không cần thay đổi.

Đối với các tìm kiếm láng giềng gần nhất quy mô lớn, có các thuật toán dựa trên cây (như KD-Tree trong Scikit-Learn) với độ phức tạp $\mathcal{O}[N\log N]$ thay vì $\mathcal{O}[N^2]$ của thuật toán brute-force. Hy vọng bài viết này giúp bạn hiểu rõ hơn về sắp xếp trong NumPy và ứng dụng của nó trong AI. Hãy thử áp dụng và chia sẻ kết quả nhé!