|
|
@@ -571,7 +571,7 @@ GO
|
|
|
return "\n".join(f"{script}\nGO" for script in self.index(ddl))
|
|
|
|
|
|
def gen_insert(self, table_name: str) -> str:
|
|
|
- """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
|
|
|
+ """生成 insert 语句"""
|
|
|
|
|
|
# 收集 `table_name` 对应的 insert 语句
|
|
|
inserts = []
|
|
|
@@ -606,13 +606,117 @@ GO
|
|
|
return script
|
|
|
|
|
|
|
|
|
+class DM8Convertor(Convertor):
|
|
|
+ def __init__(self, src):
|
|
|
+ super().__init__(src, "DM8")
|
|
|
+
|
|
|
+ def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
|
|
|
+ """类型转换"""
|
|
|
+ type = type.lower()
|
|
|
+
|
|
|
+ if type == "varchar":
|
|
|
+ return f"varchar({size})"
|
|
|
+ if type == "int":
|
|
|
+ return "int"
|
|
|
+ if type == "bigint" or type == "bigint unsigned":
|
|
|
+ return "bigint"
|
|
|
+ if type == "datetime":
|
|
|
+ return "datetime"
|
|
|
+ if type == "bit":
|
|
|
+ return "bit"
|
|
|
+ if type in ("tinyint", "smallint"):
|
|
|
+ return "smallint"
|
|
|
+ if type == "text":
|
|
|
+ return "text"
|
|
|
+ if type == "blob":
|
|
|
+ return "blob"
|
|
|
+ if type == "mediumblob":
|
|
|
+ return "varchar(10240)"
|
|
|
+ if type == "decimal":
|
|
|
+ return (
|
|
|
+ f"decimal({','.join(str(s) for s in size)})" if len(size) else "decimal"
|
|
|
+ )
|
|
|
+
|
|
|
+ def gen_create(self, ddl) -> str:
|
|
|
+ """生成 CREATE 语句"""
|
|
|
+
|
|
|
+ def generate_column(col):
|
|
|
+ name = col["name"].lower()
|
|
|
+ if name == "id":
|
|
|
+ return "id bigint NOT NULL PRIMARY KEY IDENTITY"
|
|
|
+
|
|
|
+ type = col["type"].lower()
|
|
|
+ full_type = self.translate_type(type, col["size"])
|
|
|
+ nullable = "NULL" if col["nullable"] else "NOT NULL"
|
|
|
+ default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
|
|
|
+ return f"{name} {full_type} {default} {nullable}"
|
|
|
+
|
|
|
+ table_name = ddl["table_name"].lower()
|
|
|
+ columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
|
|
|
+ field_def_list = ",\n ".join(columns)
|
|
|
+ script = f"""-- ----------------------------
|
|
|
+-- Table structure for {table_name}
|
|
|
+-- ----------------------------
|
|
|
+CREATE TABLE {table_name} (
|
|
|
+ {field_def_list}
|
|
|
+);"""
|
|
|
+
|
|
|
+ # oracle INSERT '' 不能通过 NOT NULL 校验
|
|
|
+ script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")
|
|
|
+
|
|
|
+ return script
|
|
|
+
|
|
|
+ def gen_index(self, ddl: Dict) -> str:
|
|
|
+ return "\n".join(f"{script};" for script in self.index(ddl))
|
|
|
+
|
|
|
+ def gen_comment(self, table_sql: str, table_name: str) -> str:
|
|
|
+ script = ""
|
|
|
+ for field, comment_string in self.filed_comments(table_sql):
|
|
|
+ script += (
|
|
|
+ f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
|
|
|
+ )
|
|
|
+
|
|
|
+ table_comment = self.table_comment(table_sql)
|
|
|
+ if table_comment:
|
|
|
+ script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
|
|
|
+
|
|
|
+ return script
|
|
|
+
|
|
|
+ def gen_pk(self, table_name: str) -> str:
|
|
|
+ """生成主键定义"""
|
|
|
+ return ""
|
|
|
+
|
|
|
+ def gen_index(self, ddl: Dict) -> str:
|
|
|
+ return "\n".join(f"{script};" for script in self.index(ddl))
|
|
|
+
|
|
|
+ def gen_insert(self, table_name: str) -> str:
|
|
|
+ """拷贝 INSERT 语句"""
|
|
|
+ inserts = list(Convertor.inserts(table_name, self.content))
|
|
|
+
|
|
|
+ ## 生成 insert 脚本
|
|
|
+ script = ""
|
|
|
+ if inserts:
|
|
|
+ inserts_lines = "\n".join(inserts)
|
|
|
+ script += f"""\n\n-- ----------------------------
|
|
|
+-- Records of {table_name.lower()}
|
|
|
+-- ----------------------------
|
|
|
+-- @formatter:off
|
|
|
+SET IDENTITY_INSERT {table_name.lower()} ON;
|
|
|
+{inserts_lines}
|
|
|
+COMMIT;
|
|
|
+SET IDENTITY_INSERT {table_name.lower()} OFF;
|
|
|
+-- @formatter:on"""
|
|
|
+
|
|
|
+ return script
|
|
|
+
|
|
|
+
|
|
|
def main():
|
|
|
parser = argparse.ArgumentParser(description="芋道系统数据库转换工具")
|
|
|
parser.add_argument(
|
|
|
"type",
|
|
|
type=str,
|
|
|
help="目标数据库类型",
|
|
|
- choices=["postgres", "oracle", "sqlserver"],
|
|
|
+ choices=["postgres", "oracle", "sqlserver", "dm8"],
|
|
|
)
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
@@ -624,6 +728,8 @@ def main():
|
|
|
convertor = OracleConvertor(sql_file)
|
|
|
elif args.type == "sqlserver":
|
|
|
convertor = SQLServerConvertor(sql_file)
|
|
|
+ elif args.type == "dm8":
|
|
|
+ convertor = DM8Convertor(sql_file)
|
|
|
else:
|
|
|
raise NotImplementedError(f"不支持目标数据库类型: {args.type}")
|
|
|
|