赞
踩
最近在做项目的时候,需要比对两个数据库的表结构差异。由于表的数量比较多,人工比对需要耗费大量时间,而且无法复用,于是想到用 Python 写一个脚本来满足这一需求,下次遇到相同需求时只需修改 SQL 文件名即可。
compare_diff.py:
"""Compare the table structures found in two MySQL DDL scripts.

Both scripts are scanned for ``CREATE TABLE`` statements. For every table the
column names/types and the index definitions are extracted with regular
expressions and the differences between the two scripts are printed.

Table, column and index names are lower-cased before comparison. MySQL is
case-insensitive by default; if the database is configured to be
case-sensitive this script has to be adapted.

Parsing limits that follow from the patterns below:

* the table name must be back-quoted and the statement must end with
  ``) ENGINE ... ;``
* every column definition must start on its own line
"""
import dataclasses
import json
import re
import sys


@dataclasses.dataclass
class TableStmt:
    """Raw text of one CREATE TABLE statement together with its table name."""

    table_name: str = ""
    create_stmt: str = ""


@dataclasses.dataclass
class Field:
    """One column: lower-cased name and base type, e.g. ``varchar(32)``."""

    field_name: str = ""
    field_type: str = ""


@dataclasses.dataclass
class Index:
    """One index: name, kind (``index``, ``primary key``, ...) and column list."""

    name: str = ""
    type: str = ""
    columns: str = ""


@dataclasses.dataclass
class Table:
    """A parsed table.

    ``fields`` and ``indexes`` use ``default_factory`` so that each instance
    owns its own lists instead of sharing class-level ones.
    """

    table_name: str = ""
    fields: list = dataclasses.field(default_factory=list)
    indexes: list = dataclasses.field(default_factory=list)


def obj_2_dict(obj):
    """JSON ``default`` hook for ``Field`` and ``Index`` (used for debug dumps).

    Raises:
        TypeError: if ``obj`` is neither a ``Field`` nor an ``Index``.
    """
    if isinstance(obj, Field):
        return {
            "field_name": obj.field_name,
            "field_type": obj.field_type
        }
    elif isinstance(obj, Index):
        return {
            "name": obj.name,
            "type": obj.type,
            "columns": obj.columns
        }
    raise TypeError(f"Type {type(obj)} is not serializable")


# Matches one complete CREATE TABLE statement. Everything between ENGINE and
# the terminating ';' is accepted: a restrictive character class here would
# fail on table comments containing e.g. '-' or a tab, and the lazy '.*?'
# would then run on to the next ENGINE, merging two tables into one.
create_table_pattern = re.compile(
    r"CREATE TABLE `(?P<table_name>\w+)`.*?\)\s*ENGINE[^;]+;",
    re.DOTALL | re.IGNORECASE
)

# Matches a column name and its base type at the start of a line; everything
# after the type (NOT NULL, DEFAULT, COMMENT, ...) is ignored. Whitespace is
# tolerated inside the parentheses so "decimal(10, 2)" is captured in full.
table_pattern = re.compile(
    r"^\s*`(?P<field>\w+)`\s+"
    r"(?P<type>[a-zA-Z]+(?:\(\s*\d+\s*(?:,\s*\d+\s*)?\))?)",
    re.MULTILINE
)

# Matches index definitions. The look-behind keeps a back-quoted identifier
# ending in "key" from being taken for the KEY keyword.
index_pattern = re.compile(
    r'(?<!`)KEY\s+`?(?P<key_name>\w+)`?\s*\((?P<key_cols>[^)]+)\)|'
    r'PRIMARY\s+KEY\s*\((?P<primary_cols>[^)]+)\)|'
    r'UNIQUE\s+KEY\s+`?(?P<unique_name>\w+)`?\s*\((?P<unique_cols>[^)]+)\)|'
    r'FULLTEXT\s+KEY\s+`?(?P<fulltext_name>\w+)`?\s*'
    r'\((?P<fulltext_cols>[^)]+)\)',
    re.IGNORECASE
)


def _normalize_columns(columns):
    """Lower-case an index column list and separate its items with ', '.

    This makes "(`a`,`b`)" and "(`a`, `b`)" compare equal.
    """
    return ", ".join(part.strip() for part in columns.lower().split(","))


def extract_create_table_statements(sql_script):
    """Return a ``TableStmt`` for every CREATE TABLE statement in the script."""
    return [
        TableStmt(
            table_name=match.group('table_name').lower(),
            create_stmt=match.group(0).strip(),
        )
        for match in create_table_pattern.finditer(sql_script)
    ]


def extract_indexes(sql):
    """Return the indexes declared in one CREATE TABLE statement.

    The primary key is reported under the fixed name ``primary``.
    """
    indexes = []
    for match in index_pattern.finditer(sql):
        if match.group('key_name'):
            kind, name, columns = (
                'index', match.group('key_name'), match.group('key_cols'))
        elif match.group('primary_cols'):
            kind, name, columns = (
                'primary key', 'primary', match.group('primary_cols'))
        elif match.group('unique_name'):
            kind, name, columns = (
                'unique index', match.group('unique_name'),
                match.group('unique_cols'))
        else:
            kind, name, columns = (
                'fulltext index', match.group('fulltext_name'),
                match.group('fulltext_cols'))
        indexes.append(Index(
            name=name.lower(),
            type=kind,
            columns=_normalize_columns(columns),
        ))
    return indexes


def extract_fields(sql):
    """Return the columns declared in one CREATE TABLE statement."""
    return [
        Field(
            field_name=match.group('field').lower(),
            # Drop whitespace so "decimal(10, 2)" equals "decimal(10,2)".
            field_type=re.sub(r"\s+", "", match.group('type')).lower(),
        )
        for match in table_pattern.finditer(sql)
    ]


def extract_table_info(tableStmt: TableStmt):
    """Parse one CREATE TABLE statement into a ``Table``."""
    return Table(
        table_name=tableStmt.table_name.lower(),
        fields=extract_fields(tableStmt.create_stmt),
        indexes=extract_indexes(tableStmt.create_stmt),
    )


def get_all_tables(sql_script):
    """Return ``{lower-cased table name: Table}`` for the whole script.

    If a table name occurs more than once, the last statement wins.
    """
    return {
        table.table_name: table
        for table in map(extract_table_info,
                         extract_create_table_statements(sql_script))
    }


def compare_fields(source: Table, target: Table):
    """Compare the columns of two tables.

    Returns:
        A tuple ``(source_fields_not_in_target, fields_type_not_match,
        target_fields_not_in_source)``. The first and last items are lists of
        column names, the middle one a list of descriptive messages.
    """
    source_names = {column.field_name for column in source.fields}
    target_types = {
        column.field_name: column.field_type for column in target.fields
    }

    source_fields_not_in_target = []
    fields_type_not_match = []
    for column in source.fields:
        name = column.field_name
        if name not in target_types:
            source_fields_not_in_target.append(name)
        elif column.field_type != target_types[name]:
            fields_type_not_match.append(
                f"field={name}, source type: {column.field_type}"
                f", target type: {target_types[name]}")

    # Types of shared columns were already compared in the loop above.
    target_fields_not_in_source = [
        column.field_name for column in target.fields
        if column.field_name not in source_names
    ]
    return (source_fields_not_in_target, fields_type_not_match,
            target_fields_not_in_source)


def compare_indexes(source: Table, target: Table):
    """Compare the indexes of two tables, matching them by name.

    Returns:
        A tuple ``(source_indexes_not_in_target, index_column_not_match,
        index_type_not_match, target_indexes_not_in_source)``. Columns are
        only compared when the index types are equal.
    """
    source_names = {index.name for index in source.indexes}
    target_by_name = {index.name: index for index in target.indexes}

    source_indexes_not_in_target = []
    index_column_not_match = []
    index_type_not_match = []
    for index in source.indexes:
        other = target_by_name.get(index.name)
        if other is None:
            source_indexes_not_in_target.append(index.name)
        elif index.type != other.type:
            index_type_not_match.append(
                f"name={index.name}, source type: {index.type}"
                f", target type: {other.type}")
        elif index.columns != other.columns:
            index_column_not_match.append(
                f"name={index.name}, source columns={index.columns}"
                f", target columns={other.columns}")

    target_indexes_not_in_source = [
        index.name for index in target.indexes
        if index.name not in source_names
    ]
    return (source_indexes_not_in_target, index_column_not_match,
            index_type_not_match, target_indexes_not_in_source)


def print_diff(desc, compare_result):
    """Print ``desc`` and the result; an empty result prints nothing."""
    if compare_result:
        print(f"{desc} {compare_result}")


def compare_table(source_sql_script, target_sql_script):
    """Print every structural difference between the two SQL scripts.

    Tables are filtered through the module-level ``white_list_tables`` and
    ``black_list_tables``. Their entries are lower-cased before use because
    parsed table names are always lower case.
    """
    source_table_map = get_all_tables(source_sql_script)
    target_table_map = get_all_tables(target_sql_script)
    allowed = {name.lower() for name in white_list_tables}
    blocked = {name.lower() for name in black_list_tables}

    def is_selected(table_name):
        # An empty whitelist means "no restriction".
        if allowed and table_name not in allowed:
            return False
        return table_name not in blocked

    source_table_not_in_target = []
    for key, source_table in source_table_map.items():
        if not is_selected(key):
            continue
        target_table = target_table_map.get(key)
        if target_table is None:
            source_table_not_in_target.append(key)
            continue

        (source_fields_not_in_target, fields_type_not_match,
         target_fields_not_in_source) = compare_fields(
            source_table, target_table)
        (source_indexes_not_in_target, index_column_not_match,
         index_type_not_match, target_indexes_not_in_source) = compare_indexes(
            source_table, target_table)

        print(f"====== table = {key} ======")
        print_diff("source field not in target, fields:",
                   source_fields_not_in_target)
        print_diff("target field not in source, fields:",
                   target_fields_not_in_source)
        print_diff("field type not match:", fields_type_not_match)
        print_diff("source index not in target, indexes:",
                   source_indexes_not_in_target)
        print_diff("target index not in source, indexes:",
                   target_indexes_not_in_source)
        print_diff("index type not match:", index_type_not_match)
        print_diff("index column not match:", index_column_not_match)
        print("")

    target_table_not_in_source = [
        key for key in target_table_map
        if is_selected(key) and key not in source_table_map
    ]

    print_diff("source table not in target, table list:",
               source_table_not_in_target)
    print_diff("target table not in source, table list:",
               target_table_not_in_source)


def sql_read(file_name):
    """Return the content of a UTF-8 encoded SQL file."""
    with open(file_name, "r", encoding='utf-8') as file:
        return file.read()


def print_all_tables(file_name="sql1.sql"):
    """Debug helper: dump every parsed table of ``file_name`` as JSON."""
    table_map = get_all_tables(sql_read(file_name))
    for key, item in table_map.items():
        print(key)
        print(json.dumps(item.fields, default=obj_2_dict,
                         ensure_ascii=False, indent=4))
        print(json.dumps(item.indexes, default=obj_2_dict,
                         ensure_ascii=False, indent=4))
        print("")


# Call print_all_tables() here to inspect what the parser extracted.

# White/black lists restrict the comparison to a subset of the tables.
# Whitelist: when not empty, only these tables are compared.
white_list_tables = []
# Blacklist: when not empty, these tables are skipped.
black_list_tables = []

if __name__ == '__main__':
    # Optional usage: python compare_diff.py [source.sql [target.sql]]
    cli_args = sys.argv[1:]
    source_file = cli_args[0] if len(cli_args) > 0 else "sql1.sql"
    target_file = cli_args[1] if len(cli_args) > 1 else "sql2.sql"
    source_script = sql_read(source_file)
    target_script = sql_read(target_file)
    compare_table(source_script, target_script)
运行效果如下:
====== table = table1 ======
source field not in target, fields: ['age', 'email']
target field not in source, fields: ['name']
field type not match: ['field=created_at, source type: date, target type: bigint(20)', 'field=updated_at, source type: timestamp, target type: date']
source index not in target, indexes: ['unique_name']
target index not in source, indexes: ['idx_country_env']

====== table = table2 ======
index type not match: ['name=fulltext_index, source type: fulltext index, target type: index']
index column not match: ['name=index, source columns=`age`, target columns=`description`']

====== table = table3 ======
index column not match: ['name=primary, source columns=`id`, `value`, target columns=`value`, `id`']

source table not in target, table list: ['activity_instance']
target table not in source, table list: ['table5']
结果说明:
sql1.sql:
-- Sample "source" schema for compare_diff.py.
-- Each column definition sits on its own line because the script's column
-- pattern is anchored at the start of a line.
-- NOTE(review): several keys reference columns the table does not declare
-- (for example `name` in table1). The script only parses the text, so this
-- does not affect the comparison, but the DDL will not execute as written.
CREATE TABLE `table1` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `age` INT(11) DEFAULT NULL,
  `email` varchar(32) DEFAULT NULL COMMENT '邮箱',
  `created_at` date DEFAULT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
  `id` INT(11) NOT NULL,
  `description` TEXT NOT NULL,
  `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`),
  KEY `index` (`age`),
  FULLTEXT KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`, `value`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

/******************************************/
/* DatabaseName = database */
/* TableName = activity_instance */
/******************************************/
CREATE TABLE `activity_instance` (
  `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT '主键',
  `gmt_create` bigint(20) NOT NULL COMMENT '创建时间',
  `gmt_modified` bigint(20) NOT NULL COMMENT '修改时间',
  `activity_name` varchar(400) NOT NULL COMMENT '活动名称',
  `benefit_type` varchar(16) DEFAULT NULL,
  `benefit_id` varchar(32) DEFAULT NULL,
  PRIMARY KEY (`id`),
  KEY `idx_country_env` (`env`, `country_code`),
  KEY `idx_benefit_type_id` (`benefit_type`, `benefit_id`)
) ENGINE = InnoDB AUTO_INCREMENT = 139 DEFAULT CHARSET = utf8mb4 COMMENT ='活动时间模板表' ;
sql2.sql:
-- Sample "target" schema for compare_diff.py.
-- Each column definition sits on its own line because the script's column
-- pattern is anchored at the start of a line.
-- The trailing comma that followed the last key of TABLE1 was removed: a
-- comma directly before the closing parenthesis is a syntax error in MySQL.
-- NOTE(review): some keys reference columns the table does not declare; the
-- script only parses the text, so the comparison result is unaffected.
CREATE TABLE `TABLE1` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `name` VARCHAR(255) NOT NULL,
  `created_at` bigint(20) DEFAULT NULL,
  `updated_at` date ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  KEY `idx_country_env` (`env`, `country_code`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
  `id` INT(11) NOT NULL,
  `description` TEXT NOT NULL,
  `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`),
  KEY `index` (`description`),
  KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`value`, `id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `TABLE5` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
把 python 和 sql 脚本拷贝下来分别放在同一个目录下的3个文件中即可,示例在 python 3.12 环境上成功运行。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。