From 576ed77d367dbc6cf9114fb50f1efd8e2ca13870 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Wed, 13 Nov 2024 19:26:17 +0800 Subject: [PATCH] Fix code generation file missing (#457) --- backend/app/generator/service/gen_service.py | 15 +++++++++++++++ backend/templates/py/api.jinja | 3 +-- backend/templates/py/crud.jinja | 15 +++++++++++---- backend/templates/py/model.jinja | 6 +++--- backend/templates/py/schema.jinja | 3 +-- backend/templates/py/service.jinja | 8 +++++--- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/backend/app/generator/service/gen_service.py b/backend/app/generator/service/gen_service.py index cf376cd2..124d66fc 100644 --- a/backend/app/generator/service/gen_service.py +++ b/backend/app/generator/service/gen_service.py @@ -127,6 +127,17 @@ async def generate(self, *, pk: int) -> None: if not init_filepath.exists(): async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f: await f.write(gen_template.init_content) + if 'api' in str(code_folder): + # api __init__.py + api_init_filepath = code_folder.parent.joinpath('__init__.py') + if not api_init_filepath.exists(): + async with aiofiles.open(api_init_filepath, 'w', encoding='utf-8') as f: + await f.write(gen_template.init_content) + # app __init__.py + app_init_filepath = api_init_filepath.parent.joinpath('__init__.py') + if not app_init_filepath: + async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f: + await f.write(gen_template.init_content) # 写入代码文件呢 async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f: await f.write(code) @@ -161,6 +172,10 @@ async def download(self, *, pk: int) -> io.BytesIO: f'from backend.app.{business.app_name}.model.{business.table_name_en} ' f'import {to_pascal(business.table_name_en)}\n', ) + if 'api' in new_code_path: + # api __init__.py + api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py') + zf.writestr(api_init_filepath, gen_template.init_content) zf.close() bio.seek(0) return bio diff --git a/backend/templates/py/api.jinja b/backend/templates/py/api.jinja index 1cbf936f..8a49af0b 100644 --- a/backend/templates/py/api.jinja +++ b/backend/templates/py/api.jinja @@ -2,8 +2,6 @@ # -*- coding: utf-8 -*- from typing import Annotated -from fastapi import APIRouter, Depends, Path, Query - from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Get{{ schema_name }}ListDetails, Update{{ schema_name }}Param from backend.app.{{ app_name }}.service.{{ table_name_en }}_service import {{ table_name_en }}_service from backend.common.pagination import DependsPagination, paging_data @@ -12,6 +10,7 @@ from backend.common.security.jwt import DependsJwtAuth from backend.common.security.permission import RequestPermission from backend.common.security.rbac import DependsRBAC from backend.database.db_mysql import CurrentSession +from fastapi import APIRouter, Depends, Path, Query router = APIRouter() diff --git a/backend/templates/py/crud.jinja b/backend/templates/py/crud.jinja index 6a8fd820..d0192673 100644 --- a/backend/templates/py/crud.jinja +++ b/backend/templates/py/crud.jinja @@ -2,12 +2,11 @@ # -*- coding: utf-8 -*- from typing import Sequence -from sqlalchemy import delete -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy_crud_plus import CRUDPlus - from backend.app.{{ app_name }}.model import {{ table_name_class }} from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy_crud_plus import CRUDPlus class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]): @@ -21,6 +20,14 @@ class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]): """ return await self.select_model(db, pk) + async def get_list(self) -> Select: + """ + 获取 {{ schema_name }} 列表 + + :return: + """ + return await self.select_order('created_time', 'desc') + async def get_all(self, db: AsyncSession) -> Sequence[{{ table_name_class }}]: """ 获取所有 {{ schema_name }} diff --git a/backend/templates/py/model.jinja b/backend/templates/py/model.jinja index 48d5155e..417172f7 100644 --- a/backend/templates/py/model.jinja +++ b/backend/templates/py/model.jinja @@ -1,13 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from datetime import datetime from uuid import UUID import sqlalchemy as sa -from sqlalchemy.dialects import mysql - -from sqlalchemy.orm import Mapped, mapped_column from backend.common.model import {% if have_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key +from sqlalchemy.dialects import mysql +from sqlalchemy.orm import Mapped, mapped_column class {{ table_name_class }}({% if have_datetime_column %}Base{% else %}MappedBase{% endif %}): diff --git a/backend/templates/py/schema.jinja b/backend/templates/py/schema.jinja index 642806b3..8a066bd3 100644 --- a/backend/templates/py/schema.jinja +++ b/backend/templates/py/schema.jinja @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -{% if have_datetime_column %} from datetime import datetime -{% endif %} + from pydantic import ConfigDict from backend.common.schema import SchemaBase diff --git a/backend/templates/py/service.jinja b/backend/templates/py/service.jinja index 10e29b35..08cb4e01 100644 --- a/backend/templates/py/service.jinja +++ b/backend/templates/py/service.jinja @@ -7,6 +7,7 @@ from backend.app.{{ app_name }}.model import {{ table_name_class }} from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param from backend.common.exception import errors from backend.database.db_mysql import async_db_session +from sqlalchemy import Select class {{ table_name_class }}Service: @@ -18,6 +19,10 @@ class {{ table_name_class }}Service: raise errors.NotFoundError(msg='{{ table_simple_name_zh }}不存在') return {{ table_name_en }} + @staticmethod + async def get_select() -> Select: + return await {{ table_name_en }}_dao.get_list() + @staticmethod async def get_all() -> Sequence[{{ table_name_class }}]: async with async_db_session() as db: @@ -27,9 +32,6 @@ class {{ table_name_class }}Service: @staticmethod async def create(*, obj: Create{{ schema_name }}Param) -> None: async with async_db_session.begin() as db: - {{ table_name_en }} = await {{ table_name_en }}_dao.get_by_name(db, obj.name) - if {{ table_name_en }}: - raise errors.ForbiddenError(msg='{{ table_simple_name_zh }}已存在') await {{ table_name_en }}_dao.create(db, obj) @staticmethod