LSH算法 1.LSH 算法原理分步详解 LSH 算法原理分步详解(Locality-Sensitive Hashing) 1.1、核心思想 目的:在高维空间中快速找到“相似”的数据点。 关键思想:让相似的样本在哈希后 落入同一个桶(bucket) 的概率高,不相似的样本落入同桶的概率低。 相比暴力搜索(O(N)),LSH 通过 多组随机哈希函数 将搜索复杂度降低到 亚线性(sublinear)。 1.
** LSH 算法原理分步详解(Locality-Sensitive Hashing)**
目的:在高维空间中快速找到“相似”的数据点。
关键思想:让相似的样本在哈希后 落入同一个桶(bucket) 的概率高,不相似的样本落入同桶的概率低。
相比暴力搜索(O(N)),LSH 通过 多组随机哈希函数 将搜索复杂度降低到 亚线性(sublinear)。
| 相似度度量 | 常用 LSH 变体 | 哈希思想 |
|---|---|---|
| 余弦相似度(cosine similarity) | Random Projection LSH | 用随机超平面划分空间 |
| 欧式距离(L2 distance) | p-stable LSH | 用随机投影 + 模运算近似欧氏距离 |
| Jaccard 相似度(集合相似度) | MinHash LSH | 利用最小哈希值近似集合交并比 |
在 d 维空间随机生成 k 个向量,这些向量充当超平面的法向量(Normal Vector)。
每个向量的分量服从标准正态分布
每个随机向量 r_i 唯一确定了一个经过原点且与其垂直的超平面。我们通过判断数据点位于该超平面的哪一侧来进行哈希编码。
对于一个向量 x,我们计算它与 k 个超平面的法向量 r_i 的点积。根据点积结果的正负,生成对应的二进制位:
将这 k 个结果拼接起来,就形成了一个长度为 k 的二进制串(例如 10100110)。这个二进制串就是向量 x 的 哈希签名。
✅ 直觉:两个向量的夹角越小,被随机超平面分开的概率越低(更可能在同侧),因此哈希签名越相似。
每个样本被插入到 L 个哈希表对应的桶中,从而提升召回率。
给定查询向量 q:
| 参数 | 含义 | 影响 |
|---|---|---|
| k | 每组哈希函数数量 | 越大 → 桶更小,召回率下降但精度提高 |
| L | 哈希表数量 | 越多 → 召回率上升但内存消耗大 |
| n | 样本数量 | 影响查询速度,越多收益越明显 |
一般经验值:
| 阶段 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 建表 | O(nLk) | O(nL) |
| 查询 | O(L(k + c)) ,其中 c 为候选数量 | - |
相比暴力搜索 O(n) ,LSH 查询复杂度可达 亚线性级(如 O(n^{0.5}))。
LSH算法Python实现
import numpy as np from typing import List, Union, Dict class CosineLSH: """ 基于随机超平面投影的局部敏感哈希(LSH),适用于余弦相似度测量。 参数: hash_size (int): 单个哈希表的哈希函数数量(即哈希码的位数) num_tables (int): 使用的哈希表数量 """ def __init__(self, hash_size: int = 6, num_tables: int = 5): self.hash_size = hash_size # 每个哈希表的位数 self.num_tables = num_tables # 哈希表数量 self.hash_tables = [dict() for _ in range(num_tables)] # 初始化哈希表 self.random_planes_list = [] # 存储每个哈希表的随机超平面 self.dimension = None # 数据维度(在插入数据时确定) self.data = None # 存储数据向量(便于后续使用和调试) def _generate_random_planes(self, dimension: int) -> np.ndarray: """为单个哈希表生成随机超平面(每个超平面对应一个哈希函数)""" return np.random.randn(self.hash_size, dimension) def _hash(self, vector: np.ndarray, random_planes: np.ndarray) -> str: """计算单个向量的哈希键(二进制字符串)""" # 计算向量与每个随机超平面的点积,根据符号生成二进制位 projections = np.dot(vector, random_planes.T) hash_bits = (projections > 0).astype(int) # 大于0为1,否则为0 return ''.join(hash_bits.astype(str)) # 转换为二进制字符串作为哈希键 def index(self, data: Union[List[List[float]], np.ndarray]) -> None: """ 将数据向量插入LSH索引中 参数: data: 待索引的向量列表或数组 """ data_array = np.array(data) if len(data_array.shape) == 1: data_array = data_array.reshape(1, -1) self.dimension = data_array.shape[1] # 设置数据维度 self.data = data_array # 存储数据(便于后续使用和调试) # 为每个哈希表生成随机超平面 self.random_planes_list = [ self._generate_random_planes(self.dimension) for _ in range(self.num_tables) ] # 将每个向量插入所有哈希表 for i, vector in enumerate(data_array): for table_idx in range(self.num_tables): hash_key = self._hash(vector, self.random_planes_list[table_idx]) # 将向量索引存入对应哈希桶 if hash_key in self.hash_tables[table_idx]: self.hash_tables[table_idx][hash_key].append(i) else: self.hash_tables[table_idx][hash_key] = [i] def query(self, query_vector: Union[List[float], np.ndarray], max_results: int = 10, return_all_candidates: bool = False, sort_by_distance: bool = False, return_distances: bool = False): """ 查询与给定向量相似的向量 参数: query_vector: 查询向量 max_results: 返回的最大结果数量(当return_all_candidates=False时生效) return_all_candidates: 是否返回所有候选向量(默认False) sort_by_distance: 是否按距离排序(默认False) return_distances: 是否返回距离(默认False) 返回: 如果return_distances=False: 相似向量的索引列表 如果return_distances=True: (索引列表, 距离数组) 的元组 """ if self.dimension is None: raise ValueError("请先使用index方法插入数据") query_vec = np.array(query_vector) candidates = set() # 在所有哈希表中查找候选向量 for table_idx in range(self.num_tables): hash_key = self._hash(query_vec, self.random_planes_list[table_idx]) if hash_key in self.hash_tables[table_idx]: candidates.update(self.hash_tables[table_idx][hash_key]) # 如果没有找到候选向量,尝试查找邻近桶 if not candidates: print("未找到精确匹配的候选向量,正在搜索邻近桶...") for table_idx in range(self.num_tables): original_key = self._hash(query_vec, self.random_planes_list[table_idx]) # 查找哈希码只有1位不同的桶 for i in range(self.hash_size): neighbor_key = list(original_key) neighbor_key[i] = '1' if neighbor_key[i] == '0' else '0' neighbor_key = ''.join(neighbor_key) if neighbor_key in self.hash_tables[table_idx]: candidates.update(self.hash_tables[table_idx][neighbor_key]) if not candidates: if return_distances: return np.array([], dtype=int), np.array([]) return [] cand_ids = np.array(list(candidates), dtype=int) # 如果需要排序,计算距离并排序 if sort_by_distance: dists = np.linalg.norm(self.data[cand_ids] - query_vec, axis=1) sorted_indices = np.argsort(dists) cand_ids = cand_ids[sorted_indices] if return_distances: dists = dists[sorted_indices] # 决定返回多少个结果 if not return_all_candidates: cand_ids = cand_ids[:max_results] if sort_by_distance and return_distances: dists = dists[:max_results] if return_distances: return cand_ids, dists return list(cand_ids) def get_hash_tables_info(self) -> Dict: """返回哈希表的统计信息""" info = { 'num_tables': self.num_tables, 'hash_size': self.hash_size, 'total_buckets': 0, 'average_bucket_size': 0, 'table_details': [] } total_vectors = 0 for i, table in enumerate(self.hash_tables): num_buckets = len(table) vectors_in_table = sum(len(bucket) for bucket in table.values()) total_vectors += vectors_in_table avg_size = vectors_in_table / num_buckets if num_buckets > 0 else 0 info['table_details'].append({ 'table_index': i, 'num_buckets': num_buckets, 'total_vectors': vectors_in_table, 'average_bucket_size': avg_size }) info['total_buckets'] = sum(detail['num_buckets'] for detail in info['table_details']) if info['total_buckets'] > 0: info['average_bucket_size'] = (total_vectors / info['total_buckets']) return info def query_with_topk(self, query_vector: Union[List[float], np.ndarray], k: int = 5): """ 查询并返回候选集和Top-k结果(便捷方法,用于可视化) 参数: query_vector: 查询向量 k: 返回的Top-k数量 返回: (候选ID数组, Top-k ID数组) 的元组 """ # 获取所有候选向量 all_candidates = self.query(query_vector, return_all_candidates=True, sort_by_distance=False) if len(all_candidates) == 0: return np.array([], dtype=int), np.array([], dtype=int) # 计算Top-k cand_ids = np.array(all_candidates, dtype=int) dists = np.linalg.norm(self.data[cand_ids] - np.array(query_vector), axis=1) topk = cand_ids[np.argsort(dists)[:k]] return cand_ids, topk def _query_single_table(self, query_vec, table_idx): """查询单张哈希表,返回候选点ID集合(用于可视化)""" hash_key = self._hash(query_vec, self.random_planes_list[table_idx]) return set(self.hash_tables[table_idx].get(hash_key, [])) def _compute_topk_from_candidates(self, query_vec, candidate_ids, k): """从候选集中计算Top-k最近邻(用于可视化)""" if len(candidate_ids) == 0: return np.array([], dtype=int) dists = np.linalg.norm(self.data[candidate_ids] - query_vec, axis=1) return candidate_ids[np.argsort(dists)[:k]] def compute_matches(self, query_vec): """ 计算所有数据点与查询点的哈希匹配情况(用于可视化) 返回: bits: 数据点哈希编码 (N, L, Bits) q_bits: 查询点哈希编码 (L, Bits) matches: 每个点在每张表是否匹配 (N, L) hit_counts: 每个点命中的表数 (N,) candidates: 至少命中一张表的点索引 """ if self.data is None: raise ValueError("请先使用index方法插入数据") # 将random_planes_list转换为(n_tables, hash_size, dimension)格式 planes = np.array(self.random_planes_list) # (num_tables, hash_size, dimension) # 计算投影: (N, num_tables, hash_size) proj = np.dot(self.data, planes.transpose(0, 2, 1)) q_proj = np.dot(query_vec, planes.transpose(0, 2, 1)) bits = proj > 0 q_bits = q_proj > 0 matches = np.all(bits == q_bits, axis=2) hit_counts = matches.sum(axis=1) return bits, q_bits, matches, hit_counts, np.where(hit_counts > 0)[0] @property def planes(self): """返回超平面数组,格式为 (num_tables, hash_size, dimension)(用于可视化兼容性)""" if len(self.random_planes_list) == 0: return None return np.array(self.random_planes_list) @property def num_bits(self): """返回哈希位数(用于可视化兼容性)""" return self.hash_size # 示例使用和测试 if __name__ == "__main__": # 生成示例数据 np.random.seed(42) # 设置随机种子以确保结果可重现 data_vectors = np.random.randn(100, 10) # 100个10维向量 # 创建LSH索引 print("正在构建LSH索引...") lsh = CosineLSH(hash_size=8, num_tables=3) lsh.index(data_vectors) # 显示哈希表统计信息 info = lsh.get_hash_tables_info() print(f"\n哈希表统计信息:") print(f"哈希表数量: {info['num_tables']}") print(f"总桶数: {info['total_buckets']}") print(f"平均每个桶的向量数: {info['average_bucket_size']:.2f}") # 查询示例 query_vec = data_vectors[0] # 使用第一个向量作为查询 print(f"\n查询向量索引: 0") similar_indices = lsh.query(query_vec, max_results=5) print(f"找到的相似向量索引: {similar_indices}") # 验证结果:计算实际余弦相似度 from sklearn.metrics.pairwise import cosine_similarity print("\n相似度验证:") for idx in similar_indices: similarity = cosine_similarity([query_vec], [data_vectors[idx]])[0][0] print(f"向量 {idx} 与查询向量的余弦相似度: {similarity:.4f}") # 对比线性搜索结果 print("\n=== 与线性搜索对比 ===") all_similarities = cosine_similarity([query_vec], data_vectors)[0] top_linear = np.argsort(all_similarities)[::-1][1:6] # 排除自身,取前5个 print(f"线性搜索Top-5结果: {top_linear}") # 计算召回率 lsh_recall = len(set(similar_indices) & set(top_linear)) / len(top_linear) print(f"LSH召回率(与真实Top-5相比): {lsh_recall:.2%}")
运行演示:
正在构建LSH索引... 哈希表统计信息: 哈希表数量: 3 总桶数: 189 平均每个桶的向量数: 1.59 查询向量索引: 0 找到的相似向量索引: [0, 48, 20, 54, 86] 相似度验证: 向量 0 与查询向量的余弦相似度: 1.0000 向量 48 与查询向量的余弦相似度: -0.2653 向量 20 与查询向量的余弦相似度: 0.5041 向量 54 与查询向量的余弦相似度: 0.4609 向量 86 与查询向量的余弦相似度: 0.6143 === 与线性搜索对比 === 线性搜索Top-5结果: [91 32 15 86 20] LSH召回率(与真实Top-5相比): 40.00%
import numpy as np import matplotlib.pyplot as plt from matplotlib import gridspec from matplotlib.patches import Rectangle from pathlib import Path plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'Microsoft YaHei'] plt.rcParams['axes.unicode_minus'] = False # 为了后续可视化方便,这里统一配置颜色和绘图配置。 # 颜色配置 COLORS = { "data": "#E0E0E0", # 数据点颜色(浅灰色) "query": "#FFD93D", # 查询点颜色(亮黄色) "candidates": "#777777", # 候选点颜色(中灰色) "tables": ["#E74C3C", "#3498DB", "#9B59B6"], # 不同哈希表的颜色标识 "hits": ["#A8DADC", "#76C893", "#4A7C59"], # 命中次数对应的颜色渐变 "planes": ["#A8DADC", "#3498DB", "#9B59B6", "#2ECC71"], # 超平面颜色循环使用 "hash_table": { "match_bg": "#8FBC8F", # 哈希码匹配的背景色 "no_match_bg": "#FAFAFA", # 哈希码不匹配的背景色 "match_border": "#43A047", # 哈希码匹配的边框色 "no_match_border": "#E0E0E0", # 哈希码不匹配的边框色 "no_hit_text": "#999", # 未命中时的文字颜色 }, } # 绘图配置 PLOT_CONFIG = { "figsize": (16, 10), "fontsize": {"title": 11, "legend": 8, "text": 10, "suptitle": 16, "panel_title": 12}, "alpha": { "data": 0.4, # 数据点透明度 "candidates": 0.9, # 候选点透明度(较不透明,突出显示) "hits": 0.8, # 命中点透明度(中等透明度) "planes": 0.8, # 超平面透明度(清晰可见但不抢镜) "bbox": 0.9, # 文本框背景透明度 "grid": 0.2 # 网格线透明度(很淡,只做参考) }, "sizes": {"data": 20, "query": 120, "candidates": 60, "hit_points": 50}, "ground_truth": { "face": "none", # Ground truth点填充色(空心) "edge": "#FF8C00", # Ground truth点边框色(深橙色) "text": "#FF8C00", # Ground truth文本颜色(深橙色) "marker": "s", # Ground truth点形状(方框) "alpha": 0.95, # Ground truth点透明度 "size": 100, # 统一的ground truth点大小 "linewidth": 1, # 统一的线条粗细 "zorder": 8, # 统一的图层层级 }, "grid": { "hspace": 0.3, # 子图垂直间距 "wspace": 0.25 # 子图水平间距 }, "table_wspace": 0.15, # 多表哈希子图间距 "contour_resolution": 200, # 背景网格分辨率 "save": {"dpi": 150}, "planes": { "linewidth": 2.5, "arrow": { # 超平面法向量箭头定位参数 "scale": 0.12, # 箭头大小缩放因子 "t_min": 0.05, # 箭头位置最小偏移 "t_max": 0.33, # 箭头位置最大偏移 "t_base": 0.25, # 箭头位置基础偏移 "t_step": 0.04 # 多箭头间的偏移步长 } }, "layout": { "subplots_adjust": { # 子图整体位置调整 "top": 0.92, "bottom": 0.06, "left": 0.05, "right": 0.98 } } } def _line_in_box(w, xlim, ylim): """计算超平面 w·x=0 与绘图边界框的交点,返回线段两端点""" pts = [] eps = 1e-12 # 与左右边界 (x=常数) 求交 if abs(w[1]) > eps: for x in xlim: y = -w[0] * x / w[1] if ylim[0] - 1e-9 <= y <= ylim[1] + 1e-9: pts.append([x, y]) # 与上下边界 (y=常数) 求交 if abs(w[0]) > eps: for y in ylim: x = -w[1] * y / w[0] if xlim[0] - 1e-9 <= x <= xlim[1] + 1e-9: pts.append([x, y]) if len(pts) < 2: return None # 去重:保留第一个点,然后添加与第一个点不接近的点 unique = [pts[0]] for p in pts[1:]: if not np.allclose(p, unique[0]): unique.append(p) if len(unique) >= 2: break return (np.array(unique[0]), np.array(unique[1])) if len(unique) >= 2 else None def _draw_hyperplanes(ax, planes, xlim, ylim, query, span, colors, table_idx=0, show_arrows=True): """绘制超平面及其标注""" for i, p in enumerate(planes): seg = _line_in_box(p, xlim, ylim) if not seg: continue p1, p2 = seg # Panel 1: 关注不同超平面切分,每个超平面用不同颜色,Panel 2: 关注不同的切分效果,每个表用一种颜色 color = colors["planes"][i] if show_arrows else colors["tables"][table_idx] label = f"超平面 {i+1}" if show_arrows else (f"T{table_idx+1} 超平面" if i == 0 else None) ax.plot([p1[0], p2[0]], [p1[1], p2[1]], c=color, lw=PLOT_CONFIG["planes"]["linewidth"], alpha=PLOT_CONFIG["alpha"]["planes"], label=label) if show_arrows and table_idx == 0: # 绘制法向量箭头(只在Panel 1显示) low, high = (p1, p2) if p1[1] <= p2[1] else (p2, p1) n_hat = p / (np.linalg.norm(p) + 1e-12) side = 1.0 if np.dot(query, p) >= 0 else -1.0 arrow_config = PLOT_CONFIG["planes"]["arrow"] t = np.clip(arrow_config["t_base"] + (i - (len(planes) - 1) / 2) * arrow_config["t_step"], arrow_config["t_min"], arrow_config["t_max"]) base = low + (high - low) * t arrow = side * n_hat * arrow_config["scale"] * span ax.annotate("", xy=base + arrow, xytext=base, arrowprops=dict(arrowstyle="-|>", color=color, lw=2.5), zorder=6) ax.text(*(base + arrow * 1.15), f"n{i+1}={'1' if side >= 0 else '0'}", fontsize=9, fontweight="bold", c=color, ha="center", bbox=dict(boxstyle="round,pad=0.15", fc="w", ec="none", alpha=PLOT_CONFIG["alpha"]["bbox"]), zorder=7) def setup_plot(ax, xlim, ylim, data, query, title="", data_alpha=None, data_color=None, data_size=None, query_label=True): """统一的绘图设置函数:初始化子图并绘制数据点和查询点""" ax.set(xlim=xlim, ylim=ylim, aspect="equal") if title: ax.set_title(title, fontsize=PLOT_CONFIG["fontsize"]["title"], fontweight="bold") # 绘制数据点 ax.scatter(data[:, 0], data[:, 1], c=data_color or COLORS["data"], s=data_size or PLOT_CONFIG["sizes"]["data"], alpha=data_alpha or PLOT_CONFIG["alpha"]["data"], zorder=1) # 绘制查询点 ax.scatter(*query, c=COLORS["query"], marker="*", s=PLOT_CONFIG["sizes"]["query"], ec="k", lw=1.5, zorder=10, label="查询点" if query_label else None) def draw_candidates(ax, data, matches_or_hits, table_idx=None, lsh=None, size=None, alpha=None, show_prefix=True): """绘制候选点:table_idx指定时按表着色,否则按命中次数着色""" size = size or PLOT_CONFIG["sizes"]["candidates"] alpha = alpha or PLOT_CONFIG["alpha"]["candidates"] if table_idx is not None: # 按表着色 matches = matches_or_hits if matches[:, table_idx].any(): pts = data[matches[:, table_idx]] prefix = f"T{table_idx+1} " if show_prefix else "" ax.scatter(pts[:, 0], pts[:, 1], c=COLORS["candidates"], s=size, alpha=alpha, ec="w", lw=1, zorder=5, label=f"{prefix}候选 ({len(pts)})") else: # 按命中次数着色 if lsh is None: raise ValueError("当table_idx=None时,必须提供lsh参数") hits = matches_or_hits for h in range(1, lsh.num_tables + 1): idx = np.where(hits == h)[0] if len(idx): ax.scatter(data[idx, 0], data[idx, 1], c=COLORS["hits"][h-1], s=size, alpha=min(PLOT_CONFIG["alpha"]["hits"] + h * 0.05, 1.0), ec="w", lw=1, zorder=h+2, label=f"命中{h}/{lsh.num_tables} ({len(idx)}点)") def _setup_legend(ax, fontsize=None, loc="upper left"): """统一的图例设置函数""" fontsize = fontsize or PLOT_CONFIG["fontsize"]["legend"] ax.legend(fontsize=fontsize, loc=loc) def draw_ground_truth(ax, data, gt_indices, recall=None, size_scale=1.0): """ground Truth绘制函数""" if gt_indices is None or len(gt_indices) == 0: return gt_config = PLOT_CONFIG["ground_truth"] label = f"Ground Truth (召回率: {recall:.1%})" if recall is not None else "Ground Truth" ax.scatter(data[gt_indices, 0], data[gt_indices, 1], c=gt_config["face"], marker=gt_config["marker"], s=gt_config["size"] * size_scale, ec=gt_config["edge"], lw=gt_config["linewidth"], alpha=gt_config["alpha"], label=label, zorder=gt_config["zorder"]) def print_lsh_statistics(lsh, data): """打印LSH索引统计信息""" print(f"\nLSH索引统计:") print(f"向量总数: {len(data)}") print(f"哈希表数量: {lsh.num_tables}") print(f"每表哈希位数: {lsh.hash_size}") # 计算总桶数 total_buckets = sum(len(table) for table in lsh.hash_tables) print(f"总桶数: {total_buckets}") # 打印每个表的详细信息 for ti, table in enumerate(lsh.hash_tables): avg_bucket_size = np.mean([len(bucket) for bucket in table.values()]) if table else 0 print(f"表{ti+1}: {len(table)}个桶, 平均每个桶{avg_bucket_size:.2f}个向量") def _compute_single_table_recall(lsh, query, data, gt, k, table_idx): """计算单表的召回率""" single_candidates = lsh._query_single_table(query, table_idx) if not single_candidates: return 0.0 cand_ids = np.array(list(single_candidates)) single_topk = lsh._compute_topk_from_candidates(query, cand_ids, k) return len(set(single_topk) & set(gt)) / k if k else 0.0 def visualize_lsh(lsh, query_vec, k=5): """可视化 LSH 完整流程""" if lsh.data is None: print("没有数据") return data = lsh.data query = np.asarray(query_vec) n = len(data) bits, q_bits, matches, hits, cands = lsh.compute_matches(query) xlim = (data[:, 0].min() - 1, data[:, 0].max() + 1) ylim = (data[:, 1].min() - 1, data[:, 1].max() + 1) span = min(xlim[1] - xlim[0], ylim[1] - ylim[0]) # 预先计算召回率相关数据 gt = np.argsort(np.linalg.norm(data - query, axis=1))[:k] # 真实 Top-k _, lsh_topk = lsh.query_with_topk(query, k) # LSH Top-k fig = plt.figure(figsize=PLOT_CONFIG["figsize"]) gs = gridspec.GridSpec(2, 6, **PLOT_CONFIG["grid"]) # Panel 1: 空间划分 ax1 = fig.add_subplot(gs[0, :2]) setup_plot(ax1, xlim, ylim, data, query, "1. 超平面空间划分 (Table 1)") # 用网格着色展示空间划分 xx, yy = np.meshgrid(np.linspace(*xlim, PLOT_CONFIG["contour_resolution"]), np.linspace(*ylim, PLOT_CONFIG["contour_resolution"])) gp = np.c_[xx.ravel(), yy.ravel()] planes = lsh.planes hv = np.sum((np.dot(gp, planes[0].T) > 0) * (2 ** np.arange(lsh.num_bits)), axis=1) ax1.contourf(xx, yy, hv.reshape(xx.shape), levels=2**lsh.num_bits, cmap="tab20", alpha=0.1) # 使用统一的超平面绘制函数 _draw_hyperplanes(ax1, planes[0], xlim, ylim, query, span, COLORS, table_idx=0, show_arrows=True) # 绘制候选点和ground truth点 draw_candidates(ax1, data, matches, table_idx=0, show_prefix=False) draw_ground_truth(ax1, data, gt) # Hash框显示在图层最上层 ax1.annotate(f"Hash: [{''.join(q_bits[0].astype(int).astype(str))}]", xy=query, xytext=(query[0]+0.8, query[1]+0.8), fontsize=10, fontweight="bold", bbox=dict(boxstyle="round,pad=0.3", fc="w", alpha=0.9), arrowprops=dict(arrowstyle="->", color="gray"), zorder=20) _setup_legend(ax1) ax1.grid(True, alpha=0.2) # Panel 2: 多表策略 (显示单表召回率) gs2 = gridspec.GridSpecFromSubplotSpec(1, lsh.num_tables, subplot_spec=gs[0, 2:6], wspace=PLOT_CONFIG["table_wspace"]) # 计算每个单表的召回率 table_recalls = [_compute_single_table_recall(lsh, query, data, gt, k, ti) for ti in range(lsh.num_tables)] for ti in range(lsh.num_tables): ax = fig.add_subplot(gs2[0, ti]) setup_plot(ax, xlim, ylim, data, query, f"Table {ti+1}") planes = lsh.planes _draw_hyperplanes(ax, planes[ti], xlim, ylim, query, span, COLORS, table_idx=ti, show_arrows=False) draw_candidates(ax, data, matches, table_idx=ti, size=PLOT_CONFIG["sizes"]["candidates"]*0.6) draw_ground_truth(ax, data, gt, recall=table_recalls[ti], size_scale=0.6) _setup_legend(ax, fontsize=PLOT_CONFIG["fontsize"]["legend"]-1) fig.text(0.67, 0.92, f"2. 多表策略 (L={lsh.num_tables})", fontsize=PLOT_CONFIG["fontsize"]["panel_title"], fontweight="bold", ha="center") # Panel 3: 联合召回 (按命中次数着色) ax3 = fig.add_subplot(gs[1, :2]) setup_plot(ax3, xlim, ylim, data, query, "3. 联合召回 (按命中次数着色)", data_color="lightgray", data_size=PLOT_CONFIG["sizes"]["data"]*0.5, data_alpha=0.3, query_label=False) draw_candidates(ax3, data, hits, lsh=lsh) # 绘制ground truth点 draw_ground_truth(ax3, data, gt) ax3.text(0.5, 0.02, f"候选集: {len(cands)}点 ({len(cands)/n:.0%})", transform=ax3.transAxes, ha="center", fontsize=10, fontweight="bold", bbox=dict(boxstyle="round", fc="w", alpha=PLOT_CONFIG["alpha"]["bbox"], ec="gray")) _setup_legend(ax3) # Panel 4: 哈希编码对比 ax4 = fig.add_subplot(gs[1, 2:4]) ax4.axis("off") ax4.set_title("4. 哈希编码对比", fontsize=PLOT_CONFIG["fontsize"]["title"], fontweight="bold") # 选择代表性的点展示: ground truth点 + 每个命中表数至少一个点(优先非ground truth) show_idx, seen_patterns, gt_set = [], set(), set(gt) def add(idx): pattern = tuple(tuple(bits[int(idx), ti].astype(int)) for ti in range(lsh.num_tables)) if pattern not in seen_patterns: show_idx.append(int(idx)) seen_patterns.add(pattern) return True # 添加ground truth点 for idx in gt: add(idx) # 为每个命中表数至少选一个点(优先非ground truth) for h in range(lsh.num_tables, -1, -1): candidates = np.where(hits == h)[0] if len(candidates) == 0: continue for idx in sorted(candidates, key=lambda x: int(x) in gt_set): if add(idx): break if show_idx: xs = [1.8 * i for i in range(lsh.num_tables + 2)] ax4.set_xlim(xs[0] - 1, xs[-1] + 1) ax4.set_ylim(-(len(show_idx) + 2.5) * 0.9, 1) def cell(x, y, txt, bg=None, ec=None, lw=0, **kw): if bg: ax4.add_patch(Rectangle((x-0.8, y-0.35), 1.6, 0.7, fc=bg, ec=ec, lw=lw)) ax4.text(x, y, txt, ha="center", va="center", **kw) headers = ["ID"] + [f"T{j+1}" for j in range(lsh.num_tables)] + ["命中"] for i, (h, x) in enumerate(zip(headers, xs)): c = COLORS["tables"][i-1] if 1 <= i <= lsh.num_tables else "k" cell(x, 0.55, h, fontsize=9, fontweight="bold", color=c) if 1 <= i <= lsh.num_tables: ax4.add_patch(Rectangle((x-0.85, 0.25), 1.7, 0.6, fc=["#FFEBEE", "#E3F2FD", "#F3E5F5"][(i-1)%3], ec=c, lw=2, zorder=-1)) ax4.add_patch(Rectangle((xs[0]-0.5, -0.9), xs[-1]-xs[0]+1, 0.8, fc="#FFF9C4", ec=COLORS["query"], lw=2)) cell(xs[0], -0.5, "Query", fontsize=9, fontweight="bold", color="#F57F17") for ti in range(lsh.num_tables): cell(xs[ti+1], -0.5, ''.join(q_bits[ti].astype(int).astype(str)), fontsize=9, family="monospace", fontweight="bold") for i, idx in enumerate(show_idx): y = -(i + 1.5) * 0.9 - 0.3 # 哈希表布局参数 hc = int(hits[idx]) # 如果是 ground truth 点,用特殊标记 point_label = f"V{idx}◆" if idx in gt else f"V{idx}" point_color = PLOT_CONFIG["ground_truth"]["text"] if idx in gt else "black" cell(xs[0], y, point_label, fontsize=9, color=point_color, fontweight="bold" if idx in gt else "normal") for ti in range(lsh.num_tables): m = matches[idx, ti] cell(xs[ti+1], y, ''.join(bits[idx, ti].astype(int).astype(str)), bg=COLORS["hash_table"]["match_bg"] if m else COLORS["hash_table"]["no_match_bg"], ec=COLORS["hash_table"]["match_border"] if m else COLORS["hash_table"]["no_match_border"], lw=2 if m else 1, fontsize=8, family="monospace") cell(xs[-1], y, f"{hc}/{lsh.num_tables}", fontsize=10, fontweight="bold", color=COLORS["hits"][hc-1] if hc else COLORS["hash_table"]["no_hit_text"]) ax4.text((xs[0]+xs[-1])/2, y-1.5, "绿框 = 哈希码匹配 = 被召回候选集\n◆ = Ground Truth", ha="center", fontsize=8, color=COLORS["hash_table"]["match_border"]) # Panel 5: 最终结果对比 recall = len(set(lsh_topk) & set(gt)) / k if k else 0 ax5 = fig.add_subplot(gs[1, 4:6]) setup_plot(ax5, xlim, ylim, data, query, "5. 结果对比") ax5.collections[0].set_alpha(0.3) ax5.collections[0].set_facecolor("lightgray") draw_ground_truth(ax5, data, gt, recall=recall) if len(lsh_topk): ax5.scatter(data[lsh_topk, 0], data[lsh_topk, 1], c="red", marker="o", s=40, alpha=0.8, label="LSH结果") for i, idx in enumerate(gt[:5]): ax5.annotate(f"{i+1}", data[idx], xytext=(3, 3), textcoords="offset points", fontsize=8, color="w", fontweight="bold") _setup_legend(ax5) ax5.grid(True, alpha=PLOT_CONFIG["alpha"]["grid"]) # 保存和显示 fig.suptitle(f"LSH算法可视化 | L={lsh.num_tables}表, 哈希位={lsh.hash_size}, N={n}点", fontsize=PLOT_CONFIG["fontsize"]["suptitle"], fontweight="bold", y=0.98) plt.subplots_adjust(**PLOT_CONFIG["layout"]["subplots_adjust"]) plt.show() speedup = n / len(cands) if len(cands) else float("inf") print(f"\nLSH 结果: 召回率={recall:.0%}, 加速={speedup:.1f}×\n") return recall def demonstrate_lsh(): """演示LSH算法的完整流程""" print("=" * 60) print("LSH算法完整演示") print("=" * 60) # 1. 创建示例数据 np.random.seed(42) n_samples = 20 dim = 2 # 生成具有聚类结构的数据 data = np.vstack([np.random.normal(c, 0.8, (66, 2)) for c in [(2, 2), (-2, -2), (2, -2)]]) print(f"生成{len(data)}个{dim}维数据点") # 2. 创建并初始化LSH索引 lsh = CosineLSH(hash_size=4, num_tables=3) # 3. 向LSH索引中添加数据 lsh.index(data) print("LSH索引构建完成!") # 4. 显示统计信息 print_lsh_statistics(lsh, data) # 5. 创建查询点 query_point = np.array([1.5, 1.5]) print(f"\n查询点: {query_point}") # 6. 执行查询 candidate_ids, topk_ids = lsh.query_with_topk(query_point, k=5) print(f"找到{len(candidate_ids)}个候选向量:") candidate_dists = np.linalg.norm(data[candidate_ids] - query_point, axis=1) for i, (vec_id, dist) in enumerate(zip(candidate_ids, candidate_dists)): print(f"候选向量{vec_id}: 距离={dist:.4f}") if i == 4: print("...") break if len(topk_ids) > 0: print(f"\nTop-5 最近邻结果:") topk_dists = np.linalg.norm(data[topk_ids] - query_point, axis=1) for i, (vec_id, dist) in enumerate(zip(topk_ids, topk_dists)): print(f"Top{i+1}: 向量{vec_id}, 距离={dist:.4f}") # 7. 可视化 print("\n生成可视化图表...") visualize_lsh(lsh, query_point, k=5) return lsh, data, query_point, candidate_ids # 运行演示 lsh, data, query, candidates = demonstrate_lsh()
输出结果:
============================================================ LSH算法完整演示 ============================================================ 生成198个2维数据点 LSH索引构建完成! LSH索引统计: 向量总数: 198 哈希表数量: 3 每表哈希位数: 4 总桶数: 21 表1: 7个桶, 平均每个桶28.29个向量 表2: 8个桶, 平均每个桶24.75个向量 表3: 6个桶, 平均每个桶33.00个向量 查询点: [1.5 1.5] 找到66个候选向量: 候选向量0: 距离=0.9782 候选向量1: 距离=1.9974 候选向量2: 距离=0.4422 候选向量3: 距离=2.0858 候选向量4: 距离=0.9423 ... Top-5 最近邻结果: Top1: 向量42, 距离=0.1768 Top2: 向量5, 距离=0.1815 Top3: 向量46, 距离=0.2457 Top4: 向量51, 距离=0.2667 Top5: 向量14, 距离=0.2674 生成可视化图表... LSH 结果: 召回率=100%, 加速=3.0×
图 1:空间分割基础
展示随机超平面如何将向量空间划分为哈希桶,相似的向量在概率上被映射到相同区域,体现局部敏感哈希的核心机制。
图 2:多表策略
3 个哈希表使用不同超平面并行工作,即使单表(如表3)遗漏相似向量,多表协作也能提高召回率。
图 3:多表共识
颜色深度表示向量被多少个哈希表同时选中的频率,共识度更高的向量与查询点更相似。
图 4:哈希码分析
展示向量如何转换为二进制哈希码,空间相近的向量产生相似的编码模式,验证哈希函数的相似性保持特性。
图 5:性能对比
LSH 近似搜索结果与精确搜索对比,展示在保持高召回率的同时实现显著的搜索加速。
理解LSH算法的参数对掌握其工作原理至关重要:
作用:决定每个哈希表的哈希码长度(位数)
影响:值越大,哈希桶划分越精细,相似度判断越准确,但每个桶内的向量可能越少
作用:控制使用的独立哈希表数量
影响:值越大,找到真正近邻的概率越高,但内存消耗也会增加