前一段时间迁移老系统到新系统的功能改造的时候,需要统计一下这个模块用到的所有的数据库表用来放在文档里
代码里DAO层用字符串字面量存放了查询SQL语句,SQL语句的数据库表名只会出现在FROM和JOIN关键词之后,
我搜了一下FROM有1000个,起码有1000多个表名出现在代码里,一个一个手动记录估计要半天。决定花一两个小时写个脚本统计一下
首先是匹配逻辑,基本上表都是出现在FROM和JOIN之后,
FROM[空格]table[空格]where/group/order...
JOIN[空格]Table[空格]ON
子查询里面也是这样,只要按照FOMR、空格、表名、空格、终止符。正则表达式用这个模式匹配扫一遍所有的SQL语句就可以了
但是还有其他边界情况
第一个是隐式连接由逗号隔开的多表
SELECT * FROM TABLE_A a, TABLE_B b WHERE a.id = b.id
FROM 后面需要按照逗号拆分
第二个是 UPDATE、INSERT、DELETE语句和SELECT不同
UPDATE table SET ... → 表名在 UPDATE 后面
INSERT INTO table ... → 表名在 INTO 后面
DELETE FROM table ... → FROM 已经覆盖了,不需要额外处理
第三个是 Oracle CTE临时表和DUAL伪表
情况都考虑完了就让Claude生成脚本了,
用法
python extract_tables.py /path/to/your/htdm/project -o my_report
脚本有两个参数,第一个是需要统计代码的路径(支持递归),第二个是导出的文件名(保存成csv表格和txt文本两个文件)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SQL表名提取工具
================
用途:扫描Java项目源码中以字符串形式拼接出现的SQL语句(SELECT查询),
提取其中引用的表名(保留schema前缀),统计出现次数及来源文件。
核心逻辑(单层全局扫描):
1. 从Java源文件提取所有字符串字面量,按顺序拼接成一份语料
2. 在语料上全局查找所有 FROM/JOIN 关键字
3. 每个 FROM/JOIN 后面可能跟一个表名,也可能跟逗号分隔的多个表名(老式多表写法)
4. 对每个候选 token:
- 跳过以 "(" 开头的(子查询起点,其内部的FROM会被全局扫描自然覆盖,不需要递归)
- 黑名单过滤(常见SQL函数/关键字)
- 左括号后缀校验(token后面紧跟"("说明是函数调用)
5. 子查询不做递归处理:子查询内部的 FROM/JOIN 会被同一轮全局 finditer 自然覆盖,
结果与递归等价,但代码逻辑简单得多,也更贴近人工核查的直觉
跨字符串拼接:整个文件所有字符串字面量按顺序拼接成一份语料,
保证不因SQL被拆成多段字符串而漏抓表名。代价是极少数情况下可能把
毫不相关的两段字符串拼接到一起产生误报,可用来源文件列表定位复查。
大小写:所有表名统一转大写后去重统计。
容错:单个文件读取失败时跳过并打印警告,不中断整体扫描。
"""
import argparse
import csv
import os
import re
import sys
from collections import defaultdict
# ============================================================
# 配置区
# ============================================================
# 排除的伪表(不计入统计)
EXCLUDED_TABLES = {"DUAL"}
# 黑名单:常见SQL关键字/聚合函数/内置函数名。
# 应对跨字符串拼接导致 FROM 后面意外接上函数调用的情况,比如
# "...FROM " + "ROUND(a.value,2)..." 拼接后变成 "FROM ROUND(...)"。
SQL_BLACKLIST = {
"SUM", "COUNT", "AVG", "MAX", "MIN",
"ROUND", "TRUNC", "NVL", "NVL2", "DECODE", "CASE", "WHEN", "THEN", "ELSE", "END",
"TO_CHAR", "TO_DATE", "TO_NUMBER", "SUBSTR", "INSTR", "LENGTH", "LPAD", "RPAD",
"LISTAGG", "WM_CONCAT", "ROW_NUMBER", "RANK", "DENSE_RANK", "LAG", "LEAD",
"SYSDATE", "SYSTIMESTAMP", "EXTRACT", "REPLACE", "TRIM", "UPPER", "LOWER",
"COALESCE", "GREATEST", "LEAST", "ABS", "MOD", "POWER", "SQRT", "CAST",
"SELECT", "DISTINCT", "AS", "ON", "AND", "OR", "NOT", "NULL", "IS",
"IN", "EXISTS", "BETWEEN", "LIKE", "ALL", "ANY", "SOME",
}
# Java字符串字面量:匹配双引号包裹的内容,处理转义
JAVA_STRING_LITERAL_RE = re.compile(r'"((?:[^"\\]|\\.)*)"')
# 标识符:字母/$/_开头,后跟字母/数字/下划线/$/#,可带一个 schema. 前缀
IDENT = r"[A-Za-z_$][A-Za-z0-9_$#]*"
# FROM/JOIN/UPDATE/INSERT INTO 后面跟着的"表名列表":
# - FROM/JOIN(含各种前缀)、UPDATE、INSERT INTO
# - 后面至少一个空白
# - 然后是"表名列表"——捕获到第一个终止符为止
#
# 终止符新增:SET(UPDATE table SET ... 的分隔点)
_KW = (
r"(?:INNER\s+|LEFT\s+(?:OUTER\s+)?|RIGHT\s+(?:OUTER\s+)?|FULL\s+(?:OUTER\s+)?|CROSS\s+)?JOIN"
r"|FROM|UPDATE|INSERT\s+INTO"
)
_STOP = (
r"\bWHERE\b|\bSET\b|\bGROUP\s+BY\b|\bORDER\s+BY\b|\bHAVING\b"
r"|\bSTART\s+WITH\b|\bCONNECT\s+BY\b"
r"|\bUNION\b|\bMINUS\b|\bINTERSECT\b"
r"|\b(?:INNER\s+|LEFT\s+(?:OUTER\s+)?|RIGHT\s+(?:OUTER\s+)?|FULL\s+(?:OUTER\s+)?|CROSS\s+)?JOIN\b"
r"|\bON\b"
r"|\(" # 遇到左括号就停:让子查询内部的 FROM 被 finditer 后续单独捕获,
# 而不是被当前匹配"吞进去"后因整体以"("开头而跳过
r"|$"
)
FROM_JOIN_RE = re.compile(
rf"\b(?:{_KW})\s+((?:(?!{_STOP}).)+)",
re.IGNORECASE,
)
# 单个表名 token:从字符串开头匹配,可带 schema. 前缀
TOKEN_RE = re.compile(rf"^\s*({IDENT}(?:\.{IDENT})?)")
# CTE 别名识别:WITH name AS ( ...
# 只捕获别名本身,不需要找到匹配的右括号(别名识别到 AS 就够了)
CTE_WITH_RE = re.compile(r"\bWITH\b", re.IGNORECASE)
CTE_NAME_RE = re.compile(rf"\b({IDENT})\s+AS\s*\(", re.IGNORECASE)
# ============================================================
# 核心逻辑
# ============================================================
def collect_cte_names(corpus):
"""
扫描语料中所有 WITH name AS (...) 定义,返回别名的大写集合。
实现:从每个 WITH 关键字开始,向后连续匹配 "name AS (" 模式,
找到一个 CTE 别名后,跳过其括号内部(用括号计数器),
再继续找下一个(逗号分隔),直到不再符合模式为止。
"""
cte_names = set()
for wm in CTE_WITH_RE.finditer(corpus):
pos = wm.end()
n = len(corpus)
while pos < n:
# 跳过空白和逗号(多个 CTE 之间用逗号分隔)
while pos < n and corpus[pos] in " \t\r\n,":
pos += 1
m = CTE_NAME_RE.match(corpus, pos)
if not m:
break # 不再是 "name AS (" 模式,WITH 的 CTE 列表结束
cte_names.add(m.group(1).upper())
# m.end() 指向 "(" 之后第一个字符,需要从 "(" 开始数括号
# 找到这个 AS 后面的 "(" 的位置(m.end()-1 就是那个左括号)
paren_pos = m.end() - 1 # CTE_NAME_RE 以 \( 结尾,m.end()-1 是 "("
depth = 1
pos = paren_pos + 1
while pos < n and depth > 0:
if corpus[pos] == "(":
depth += 1
elif corpus[pos] == ")":
depth -= 1
pos += 1
# pos 现在指向匹配的右括号之后,继续找下一个 CTE 定义
return cte_names
def extract_tables_from_corpus(corpus, cte_names=None):
"""
从拼接好的SQL语料中提取所有表名(原始大小写,调用方负责转大写)。
步骤:
1. 先收集语料中所有 CTE 别名(WITH name AS (...)),提取时过滤掉
2. 全局匹配所有 FROM/JOIN/UPDATE/INSERT INTO 子句片段
3. 每个片段按顶层逗号拆分(处理 FROM A, B, C 的多表写法)
4. 每个 token 依次过滤:子查询括号 / 黑名单 / CTE别名 / 左括号后缀
关于左括号后缀校验的实现方式:
终止符包含了 "("(为了让子查询内部的 FROM 被 finditer 后续单独捕获),
所以 clause 在遇到 "(" 时就已经截断,token 里不会带后缀括号。
因此不能看 token 自身的后缀,而要回到 corpus 里,
找到该 token 在原文中的起始位置,检查其后面紧接的字符是否是 "("。
"""
if cte_names is None:
cte_names = collect_cte_names(corpus)
results = []
for m in FROM_JOIN_RE.finditer(corpus):
clause = m.group(1)
clause_start = m.start(1) # clause 在 corpus 里的起始偏移,用于定位 token 的原始位置
# INSERT INTO table (col,...) VALUES ... 里表名后面紧跟的 "(" 是列名列表,
# 不是函数调用,对这个关键字跳过左括号后缀校验
is_insert = m.group(0).upper().startswith("INSERT")
tokens = _split_top_level_comma(clause)
offset = 0 # 在 clause 内的当前偏移,用于计算每个 token 在 corpus 里的位置
for token in tokens:
token_len = len(token)
token_stripped = token.strip()
if token_stripped:
if token_stripped.startswith("("):
# 子查询起点,其内部 FROM 会被 finditer 后续单独捕获
pass
else:
tm = TOKEN_RE.match(token_stripped)
if tm:
name = tm.group(1)
# 黑名单、CTE别名、伪表过滤
if (name.upper() not in SQL_BLACKLIST
and name.upper() not in EXCLUDED_TABLES
and name.upper() not in cte_names):
# 左括号后缀校验:在 corpus 原文里找 token 名称后面的字符
# token 在 corpus 里的起始 = clause_start + offset + 前导空白数
leading_spaces = len(token) - len(token.lstrip())
name_end_in_corpus = clause_start + offset + leading_spaces + len(name)
# 跳过 corpus 里紧跟 name 的空白,看第一个非空白字符
pos = name_end_in_corpus
while pos < len(corpus) and corpus[pos] in " \t\r\n":
pos += 1
if not is_insert and pos < len(corpus) and corpus[pos] == "(": pass # 函数调用,跳过(INSERT INTO 时不做此校验) else: results.append(name) offset += token_len + 1 # +1 是逗号分隔符 return results def _split_top_level_comma(text): """ 按顶层逗号拆分(括号内的逗号不算分隔符)。 用于处理 FROM A, B, C 这种多表写法。 """ parts = [] depth = 0 start = 0 for i, ch in enumerate(text): if ch == "(": depth += 1 elif ch == ")": depth -= 1 elif ch == "," and depth == 0: parts.append(text[start:i]) start = i + 1 parts.append(text[start:]) return parts def extract_string_literals(java_source): """ 从Java源码中提取所有字符串字面量,拼接成一份语料。 用不可见的换行作分隔,避免相邻字符串首尾字符意外拼成新词。 """ literals = JAVA_STRING_LITERAL_RE.findall(java_source) unescaped = [lit.replace('\\"', '"').replace("\\\\", "\\") for lit in literals] return "\n".join(unescaped) # ============================================================ # 文件扫描与统计 # ============================================================ def scan_file(filepath): """读取单个.java文件,返回提取到的表名列表(已大写化)。""" for enc in ("utf-8", "gbk", "gb2312", "utf-16"): try: with open(filepath, encoding=enc) as f: content = f.read() break except (UnicodeDecodeError, UnicodeError): continue except OSError as e: print(f"[警告] 读取文件失败,已跳过: {filepath} ({e})", file=sys.stderr) return [] else: print(f"[警告] 无法以任何编码读取文件,已跳过: {filepath}", file=sys.stderr) return [] try: corpus = extract_string_literals(content) return [t.upper() for t in extract_tables_from_corpus(corpus)] except Exception as e: print(f"[警告] 解析文件时出错,已跳过: {filepath} ({e})", file=sys.stderr) return [] def scan_directory(root_dir): """递归扫描目录下所有.java文件,返回表名→文件集合、表名→出现次数。""" table_to_files = defaultdict(set) table_to_count = defaultdict(int) java_files = [ os.path.join(dp, f) for dp, _, files in os.walk(root_dir) for f in files if f.lower().endswith(".java") ] if not java_files: print(f"[提示] 在目录 {root_dir} 下未找到任何 .java 文件", file=sys.stderr) for filepath in java_files: for table in scan_file(filepath): table_to_files[table].add(filepath) table_to_count[table] += 1 return table_to_files, table_to_count # ============================================================ # 输出 # ============================================================ def print_summary(table_to_count, table_to_files): if not table_to_count: print("未提取到任何表名。") return rows = sorted(table_to_count.items(), key=lambda kv: (-kv[1], kv[0])) w = max(len(n) for n, _ in rows) w = max(w, len("表名")) print() print(f"{'表名'.ljust(w)} 出现次数 涉及文件数") print("-" * (w + 24)) for name, count in rows: print(f"{name.ljust(w)} {count:>6} {len(table_to_files[name]):>8}")
print("-" * (w + 24))
print(f"共计 {len(rows)} 个不同表名,{sum(table_to_count.values())} 次出现")
print()
def write_txt(table_to_count, table_to_files, path):
rows = sorted(table_to_count.items(), key=lambda kv: (-kv[1], kv[0]))
with open(path, "w", encoding="utf-8") as f:
f.write("SQL表名统计结果\n" + "=" * 60 + "\n\n")
for name, count in rows:
files = sorted(table_to_files[name])
f.write(f"表名: {name}\n出现次数: {count}\n涉及文件数: {len(files)}\n来源文件:\n")
for fp in files:
f.write(f" - {fp}\n")
f.write("\n")
f.write("=" * 60 + f"\n共计 {len(rows)} 个不同表名,{sum(table_to_count.values())} 次出现\n")
def write_csv(table_to_count, table_to_files, path):
rows = sorted(table_to_count.items(), key=lambda kv: (-kv[1], kv[0]))
with open(path, "w", encoding="utf-8-sig", newline="") as f:
w = csv.writer(f)
w.writerow(["表名", "出现次数", "涉及文件数", "来源文件列表"])
for name, count in rows:
files = sorted(table_to_files[name])
w.writerow([name, count, len(files), "; ".join(files)])
# ============================================================
# 入口
# ============================================================
def main():
parser = argparse.ArgumentParser(
description="扫描Java项目源码中以字符串拼接形式出现的SQL,统计SELECT查询涉及的表名。"
)
parser.add_argument("directory", help="待扫描的项目根目录(递归扫描所有.java文件)")
parser.add_argument("-o", "--output-prefix", default="sql_table_report",
help="输出文件名前缀(默认: sql_table_report)")
args = parser.parse_args()
if not os.path.isdir(args.directory):
print(f"[错误] 目录不存在: {args.directory}", file=sys.stderr)
sys.exit(1)
print(f"开始扫描目录: {args.directory}")
table_to_files, table_to_count = scan_directory(args.directory)
print_summary(table_to_count, table_to_files)
txt_path = f"{args.output_prefix}.txt"
csv_path = f"{args.output_prefix}.csv"
write_txt(table_to_count, table_to_files, txt_path)
write_csv(table_to_count, table_to_files, csv_path)
print(f"结果已保存:\n - {os.path.abspath(txt_path)}\n - {os.path.abspath(csv_path)}")
if __name__ == "__main__":
main()
文章评论