当前位置:   article > 正文

python 比较 mysql 表结构差异

python 比较 mysql 表结构差异

最近在做项目时,需要比对两个数据库的表结构差异。由于表的数量较多,人工比对耗时长且无法复用,于是用 python 写了一个脚本来完成这项工作,下次有相同需求时只需修改 sql 文件名即可。

compare_diff.py:

import re
import json


# A single CREATE TABLE statement cut out of a SQL script.
class TableStmt(object):
    """Holder for one extracted CREATE TABLE statement."""

    # Table name captured from the statement (the extractor lower-cases it).
    table_name: str = ""
    # Full statement text, from CREATE TABLE up to the terminating ';'.
    create_stmt: str = ""


# Parsed structure of one table.
class Table(object):
    """Name, columns (Field objects) and indexes (Index objects) of one table."""

    def __init__(self, table_name="", fields=None, indexes=None):
        # Lists are created per instance: class-level ``[]`` defaults would be
        # one shared list object for every Table that never reassigns them.
        self.table_name = table_name
        self.fields = [] if fields is None else list(fields)
        self.indexes = [] if indexes is None else list(indexes)


# One column of a table.
class Field(object):
    """Column name plus its basic type, both as extracted from the DDL."""

    # Column name (lower-cased by the extractor).
    field_name: str = ""
    # Basic type with optional length, e.g. "int(11)" (lower-cased by the extractor).
    field_type: str = ""


# One index / key of a table.
class Index(object):
    """Index name, kind and raw column list as extracted from the DDL."""

    # Index name; primary keys are stored under the name "primary".
    name: str = ""
    # One of "index", "primary key", "unique index", "fulltext index".
    type: str = ""
    # Text between the parentheses of the key definition, e.g. "`id`, `value`".
    columns: str = ""


# Custom ``default=`` hook for json.dumps; optional, only used when printing.
def obj_2_dict(obj):
    """Turn a Field or Index into a plain dict so json.dumps can serialise it.

    Raises:
        TypeError: for any other type, as json.dumps expects from a default hook.
    """
    if isinstance(obj, Field):
        attribute_names = ("field_name", "field_type")
    elif isinstance(obj, Index):
        attribute_names = ("name", "type", "columns")
    else:
        raise TypeError(f"Type {type(obj)} is not serializable")
    # Key order follows the tuple, so the JSON layout stays stable.
    return {attr: getattr(obj, attr) for attr in attribute_names}


# Regex matching one complete CREATE TABLE statement.
# - Accepts any whitespace between CREATE and TABLE, and an optional IF NOT EXISTS.
# - Table options after ENGINE may contain any character except ';' (commas,
#   parentheses, full-width punctuation inside COMMENT, ...). With a narrower
#   character class the tail fails on such options, the lazy ``.*?`` backtracks
#   into the FOLLOWING statement, and two tables are silently merged into one.
# Limitation: a statement without an ENGINE clause is not handled (it may be
# absorbed into the next statement).
create_table_pattern = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`(?P<table_name>\w+)`.*?\)\s*ENGINE[^;]*;",
    re.DOTALL | re.IGNORECASE
)

# Regex matching a column line: backticked name followed by its basic type.
# Only the leading type word and an optional "(n)" / "(n,m)" length are kept;
# modifiers such as UNSIGNED, NOT NULL, DEFAULT and COMMENT are ignored.
# Note: "(10, 2)" with a space is not matched, so only the bare type is captured.
table_pattern = re.compile(
    r"^\s*`(?P<field>\w+)`\s+(?P<type>[a-zA-Z]+(?:\(\d+(?:,\d+)?\))?)",
    re.MULTILINE
)

# Regex matching index definitions. Four alternatives, giving 7 groups in total:
#   KEY `name` (cols)          -> groups 1, 2
#   PRIMARY KEY (cols)         -> group 3
#   UNIQUE KEY `name` (cols)   -> groups 4, 5
#   FULLTEXT KEY `name` (cols) -> groups 6, 7
# The (?<!`) look-behind rejects a "KEY" immediately preceded by a backtick.
# Note: the "INDEX" keyword form (e.g. "INDEX idx (...)") is not recognised.
index_pattern = re.compile(r'(?<!`)KEY\s+`?(\w+)`?\s*\(([^)]+)\)|'
                           r'PRIMARY\s+KEY\s*\(([^)]+)\)|'
                           r'UNIQUE\s+KEY\s+`?(\w+)`?\s*\(([^)]+)\)|'
                           r'FULLTEXT\s+KEY\s+`?(\w+)`?\s*\(([^)]+)\)',
                           re.IGNORECASE)


# Extract every table name together with its CREATE TABLE statement.
def extract_create_table_statements(sql_script):
    """Return one TableStmt per CREATE TABLE statement found in *sql_script*.

    Statements are returned in the order they appear. Each table name is
    lower-cased, and each statement text is stripped of surrounding whitespace.
    """
    def to_stmt(hit):
        stmt = TableStmt()
        # Lower-case so names compare case-insensitively across scripts.
        stmt.table_name = hit.group('table_name').lower()
        stmt.create_stmt = hit.group(0).strip()
        return stmt

    return [to_stmt(hit) for hit in create_table_pattern.finditer(sql_script)]


# Extract the index definitions of one table.
def _normalize_index_columns(columns):
    """Canonicalise a raw index column list to trimmed, ", "-separated items.

    "`a`,`b`", "`a`, `b`" and a list split across lines all become "`a`, `b`".
    """
    return ", ".join(part.strip() for part in columns.split(","))


def extract_indexes(sql):
    """Return the Index objects declared in the CREATE TABLE statement *sql*.

    Each ``index_pattern.findall`` tuple has one populated alternative, which
    determines the index kind. Index names and column lists are lower-cased.
    Column lists are normalised so that differences in spacing or line breaks
    around commas are not reported as index column mismatches.
    """
    indexes = []
    for match in index_pattern.findall(sql):
        index = Index()
        if match[0]:  # plain index: KEY `name` (cols)
            index.type, name, columns = 'index', match[0], match[1]
        elif match[2]:  # primary key: PRIMARY KEY (cols), stored under "primary"
            index.type, name, columns = 'primary key', 'primary', match[2]
        elif match[3]:  # unique index: UNIQUE KEY `name` (cols)
            index.type, name, columns = 'unique index', match[3], match[4]
        elif match[5]:  # fulltext index: FULLTEXT KEY `name` (cols)
            index.type, name, columns = 'fulltext index', match[5], match[6]
        else:
            # Not reachable with index_pattern as written; skip instead of
            # recording an Index with empty name/type/columns.
            continue
        index.name = name.lower()
        index.columns = _normalize_index_columns(columns).lower()
        indexes.append(index)
    return indexes


# Extract the columns of one table.
def extract_fields(sql):
    """Return a Field for every column line of *sql* matched by ``table_pattern``.

    Column names and types are lower-cased so comparisons are case-insensitive.
    """
    def build_field(hit):
        column = Field()
        column.field_name = hit.group('field').lower()
        column.field_type = hit.group('type').lower()
        return column

    return list(map(build_field, table_pattern.finditer(sql)))


# Build a Table (name, columns, indexes) from one CREATE TABLE statement.
def extract_table_info(tableStmt: TableStmt):
    """Parse *tableStmt* into a Table with a lower-cased name, its fields and its indexes."""
    ddl = tableStmt.create_stmt
    parsed = Table()
    parsed.table_name = tableStmt.table_name.lower()
    parsed.indexes = extract_indexes(ddl)
    parsed.fields = extract_fields(ddl)
    return parsed


# Parse every table defined in a SQL script.
def get_all_tables(sql_script):
    """Return ``{table_name: Table}`` for each CREATE TABLE in *sql_script*.

    If the same table name appears more than once, the later definition wins.
    """
    tables = (extract_table_info(stmt) for stmt in extract_create_table_statements(sql_script))
    return {tbl.table_name: tbl for tbl in tables}


# Compare the columns of two tables.
def compare_fields(source: Table, target: Table):
    """Diff the columns of *source* and *target*, matching columns by name.

    Returns a 3-tuple, in this order:
      * names of columns present only in *source*;
      * "field=..., source type: ..., target type: ..." messages for columns
        present in both tables with different types;
      * names of columns present only in *target*.
    """
    target_by_name = {col.field_name: col for col in target.fields}
    source_names = {col.field_name for col in source.fields}

    only_in_source = [col.field_name for col in source.fields
                      if col.field_name not in target_by_name]

    # Types are compared once here, for columns present on both sides.
    type_mismatches = [
        "field=" + col.field_name + ", source type: " + col.field_type
        + ", target type: " + target_by_name[col.field_name].field_type
        for col in source.fields
        if col.field_name in target_by_name
        and col.field_type != target_by_name[col.field_name].field_type
    ]

    only_in_target = [col.field_name for col in target.fields
                      if col.field_name not in source_names]

    return only_in_source, type_mismatches, only_in_target


# Compare the indexes of two tables.
def compare_indexes(source: Table, target: Table):
    """Diff the indexes of *source* and *target*, matching indexes by name.

    Returns a 4-tuple, in this order:
      * names of indexes present only in *source*;
      * "name=..., source columns=..., target columns=..." messages for
        same-name, same-type indexes whose column lists differ;
      * "name=..., source type: ..., target type: ..." messages for same-name
        indexes of different types (their columns are not compared);
      * names of indexes present only in *target*.
    """
    target_by_name = {idx.name: idx for idx in target.indexes}
    source_names = {idx.name for idx in source.indexes}

    only_in_source = []
    column_mismatches = []
    type_mismatches = []
    for idx in source.indexes:
        if idx.name not in target_by_name:
            only_in_source.append(idx.name)
        elif idx.type != target_by_name[idx.name].type:
            other = target_by_name[idx.name]
            type_mismatches.append(
                "name=" + idx.name + ", source type: " + idx.type + ", target type: " + other.type)
        elif idx.columns != target_by_name[idx.name].columns:
            other = target_by_name[idx.name]
            column_mismatches.append(
                "name=" + idx.name + ", source columns=" + idx.columns + ", target columns=" + other.columns)

    only_in_target = [idx.name for idx in target.indexes if idx.name not in source_names]

    return only_in_source, column_mismatches, type_mismatches, only_in_target


# Print one comparison result; an empty list means "no difference" and is skipped.
def print_diff(desc, compare_result):
    """Print *desc* followed by *compare_result*, unless the result is empty."""
    if not compare_result:
        return
    print(f"{desc} {compare_result}")


# Compare all tables of two SQL scripts and print the differences.
def _lowered(names):
    """Lower-case a list of table names so it matches the lower-cased parsed names."""
    return {name.lower() for name in names}


def _print_table_diff(key, source_table, target_table):
    """Print the column and index differences of one table present in both scripts."""
    (source_fields_not_in_target, fields_type_not_match,
     target_fields_not_in_source) = compare_fields(source_table, target_table)

    (source_indexes_not_in_target, index_column_not_match,
     index_type_not_match, target_indexes_not_in_source) = compare_indexes(source_table, target_table)

    print(f"====== table = {key} ======")
    print_diff("source field not in target, fields:", source_fields_not_in_target)
    print_diff("target field not in source, fields:", target_fields_not_in_source)
    print_diff("field type not match:", fields_type_not_match)
    print_diff("source index not in target, indexes:", source_indexes_not_in_target)
    print_diff("target index not in source, indexes:", target_indexes_not_in_source)
    print_diff("index type not match:", index_type_not_match)
    print_diff("index column not match:", index_column_not_match)
    print("")


def compare_table(source_sql_script, target_sql_script, white_list=None, black_list=None):
    """Compare every table of two SQL scripts and print the differences.

    Args:
        source_sql_script: text of the "source" SQL script.
        target_sql_script: text of the "target" SQL script.
        white_list: table names to restrict the comparison to; an empty list
            means "all tables". Defaults to the module-level ``white_list_tables``.
        black_list: table names to skip. Defaults to the module-level
            ``black_list_tables``.

    Tables present in both scripts get a per-table report. Tables present in
    only one script are listed at the end. List entries are matched
    case-insensitively, because parsed table names are lower-cased.
    """
    # Module-level lists are looked up at call time, so edits to them still apply.
    white = _lowered(white_list_tables if white_list is None else white_list)
    black = _lowered(black_list_tables if black_list is None else black_list)

    def selected(name):
        # Empty white list => compare everything; the black list always applies.
        return (not white or name in white) and name not in black

    source_table_map = get_all_tables(source_sql_script)
    target_table_map = get_all_tables(target_sql_script)

    source_table_not_in_target = []
    for key, source_table in source_table_map.items():
        if not selected(key):
            continue
        if key not in target_table_map:
            # Table exists only in the source script.
            source_table_not_in_target.append(key)
            continue
        _print_table_diff(key, source_table, target_table_map[key])

    # Tables that exist only in the target script.
    target_table_not_in_source = [
        key for key in target_table_map
        if selected(key) and key not in source_table_map
    ]

    print_diff("source table not in target, table list:", source_table_not_in_target)
    print_diff("target table not in source, table list:", target_table_not_in_source)


# Read a SQL file.
def sql_read(file_name):
    """Return the entire contents of *file_name*, decoded as UTF-8."""
    with open(file_name, mode="r", encoding="utf-8") as handle:
        content = handle.read()
    return content


def print_all_tables(file_name="sql1.sql"):
    """Debug helper: dump the parsed fields and indexes of every table in *file_name*.

    Prints each table name, then its fields and its indexes as indented JSON,
    serialised through obj_2_dict. *file_name* defaults to "sql1.sql", the file
    this helper previously always read.
    """
    table_map = get_all_tables(sql_read(file_name))
    for table_name, table in table_map.items():
        print(table_name)
        # ensure_ascii=False keeps non-ASCII text readable in the output.
        print(json.dumps(table.fields, default=obj_2_dict, ensure_ascii=False, indent=4))
        print(json.dumps(table.indexes, default=obj_2_dict, ensure_ascii=False, indent=4))
        print("")


# Debugging aid: uncomment to dump what the regexes extracted from each table.
# print_all_tables()

# White/black lists, for comparing only a subset of all tables.
# White list: when non-empty, only these tables are compared.
white_list_tables = []
# Black list: when non-empty, these tables are skipped.
black_list_tables = []

if __name__ == '__main__':
    import sys

    # NOTE: MySQL is case-insensitive by default, so every table, column and
    # index name is lower-cased before comparison. If the database is
    # configured to be case-sensitive, the script needs to be adapted.
    # Usage: python compare_diff.py [source.sql [target.sql]]
    # Without arguments the original file names sql1.sql / sql2.sql are used.
    source_file = sys.argv[1] if len(sys.argv) > 1 else "sql1.sql"
    target_file = sys.argv[2] if len(sys.argv) > 2 else "sql2.sql"
    source_script = sql_read(source_file)
    target_script = sql_read(target_file)
    compare_table(source_script, target_script)


运行效果如下:

====== table = table1 ======
source field not in target, fields: ['age', 'email']
target field not in source, fields: ['name']
field type not match: ['field=created_at, source type: date, target type: bigint(20)', 'field=updated_at, source type: timestamp, target type: date']
source index not in target, indexes: ['unique_name']
target index not in source, indexes: ['idx_country_env']

====== table = table2 ======
index type not match: ['name=fulltext_index, source type: fulltext index, target type: index']
index column not match: ['name=index, source columns=`age`, target columns=`description`']

====== table = table3 ======
index column not match: ['name=primary, source columns=`id`, `value`, target columns=`value`, `id`']

source table not in target, table list: ['activity_instance']
target table not in source, table list: ['table5']

结果说明:

  • 对于在两个 sql 脚本中都存在的表,逐表打印 source 与 target 之间的字段差异和索引差异
  • 最后打印仅存在于其中一个 sql 脚本中的表列表

sql1.sql:

CREATE TABLE `table1` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `age` INT(11) DEFAULT NULL,
  `email` varchar(32)   DEFAULT NULL COMMENT '邮箱',
  `created_at` date DEFAULT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
  `id` INT(11) NOT NULL,
  `description` TEXT NOT NULL,
  `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`),
  KEY `index` (`age`),
  FULLTEXT KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`, `value`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

/******************************************/
/*   DatabaseName = database   */
/*   TableName = activity_instance   */
/******************************************/
CREATE TABLE `activity_instance`
(
    `id`                   bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT '主键',
    `gmt_create`           bigint(20) NOT NULL COMMENT '创建时间',
    `gmt_modified`         bigint(20) NOT NULL COMMENT '修改时间',
    `activity_name`        varchar(400)  NOT NULL COMMENT '活动名称',
    `benefit_type`         varchar(16)   DEFAULT NULL,
    `benefit_id`           varchar(32)   DEFAULT NULL,
    PRIMARY KEY (`id`),
    KEY `idx_country_env` (`env`, `country_code`),
    KEY `idx_benefit_type_id` (`benefit_type`, `benefit_id`)
) ENGINE = InnoDB
  AUTO_INCREMENT = 139
  DEFAULT CHARSET = utf8mb4 COMMENT ='活动时间模板表'
;

sql2.sql:

CREATE TABLE `TABLE1` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `name` VARCHAR(255) NOT NULL,
  `created_at` bigint(20) DEFAULT NULL,
  `updated_at` date ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  KEY `idx_country_env` (`env`, `country_code`),
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
  `id` INT(11) NOT NULL,
  `description` TEXT NOT NULL,
  `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`),
  KEY `index` (`description`),
  KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`value`, `id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `TABLE5` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
  • 31

将 python 脚本和两个 sql 脚本分别保存为同一目录下的 3 个文件(compare_diff.py、sql1.sql、sql2.sql),然后运行 compare_diff.py 即可。示例已在 python 3.12 环境下运行通过。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/690777
推荐阅读
相关标签
  

闽ICP备14008679号