Skip to content

Commit

Permalink
Fix code generation file missing (#457)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Nov 13, 2024
1 parent 06cbe56 commit 576ed77
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 14 deletions.
15 changes: 15 additions & 0 deletions backend/app/generator/service/gen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ async def generate(self, *, pk: int) -> None:
if not init_filepath.exists():
async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
if 'api' in str(code_folder):
# api __init__.py
api_init_filepath = code_folder.parent.joinpath('__init__.py')
if not api_init_filepath.exists():
async with aiofiles.open(api_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# app __init__.py
app_init_filepath = api_init_filepath.parent.joinpath('__init__.py')
if not app_init_filepath:
async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# 写入代码文件呢
async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f:
await f.write(code)
Expand Down Expand Up @@ -161,6 +172,10 @@ async def download(self, *, pk: int) -> io.BytesIO:
f'from backend.app.{business.app_name}.model.{business.table_name_en} '
f'import {to_pascal(business.table_name_en)}\n',
)
if 'api' in new_code_path:
# api __init__.py
api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py')
zf.writestr(api_init_filepath, gen_template.init_content)
zf.close()
bio.seek(0)
return bio
Expand Down
3 changes: 1 addition & 2 deletions backend/templates/py/api.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Depends, Path, Query

from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Get{{ schema_name }}ListDetails, Update{{ schema_name }}Param
from backend.app.{{ app_name }}.service.{{ table_name_en }}_service import {{ table_name_en }}_service
from backend.common.pagination import DependsPagination, paging_data
Expand All @@ -12,6 +10,7 @@ from backend.common.security.jwt import DependsJwtAuth
from backend.common.security.permission import RequestPermission
from backend.common.security.rbac import DependsRBAC
from backend.database.db_mysql import CurrentSession
from fastapi import APIRouter, Depends, Path, Query

router = APIRouter()

Expand Down
15 changes: 11 additions & 4 deletions backend/templates/py/crud.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
# -*- coding: utf-8 -*-
from typing import Sequence

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus

from backend.app.{{ app_name }}.model import {{ table_name_class }}
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus


class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]):
Expand All @@ -21,6 +20,14 @@ class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]):
"""
return await self.select_model(db, pk)

async def get_list(self) -> Select:
"""
获取 {{ schema_name }} 列表
:return:
"""
return await self.select_order('created_time', 'desc')

async def get_all(self, db: AsyncSession) -> Sequence[{{ table_name_class }}]:
"""
获取所有 {{ schema_name }}
Expand Down
6 changes: 3 additions & 3 deletions backend/templates/py/model.jinja
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from uuid import UUID

import sqlalchemy as sa
from sqlalchemy.dialects import mysql

from sqlalchemy.orm import Mapped, mapped_column

from backend.common.model import {% if have_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key
from sqlalchemy.dialects import mysql
from sqlalchemy.orm import Mapped, mapped_column


class {{ table_name_class }}({% if have_datetime_column %}Base{% else %}MappedBase{% endif %}):
Expand Down
3 changes: 1 addition & 2 deletions backend/templates/py/schema.jinja
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
{% if have_datetime_column %}
from datetime import datetime
{% endif %}

from pydantic import ConfigDict

from backend.common.schema import SchemaBase
Expand Down
8 changes: 5 additions & 3 deletions backend/templates/py/service.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from backend.app.{{ app_name }}.model import {{ table_name_class }}
from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
from backend.common.exception import errors
from backend.database.db_mysql import async_db_session
from sqlalchemy import Select


class {{ table_name_class }}Service:
Expand All @@ -18,6 +19,10 @@ class {{ table_name_class }}Service:
raise errors.NotFoundError(msg='{{ table_simple_name_zh }}不存在')
return {{ table_name_en }}

@staticmethod
async def get_select() -> Select:
return await {{ table_name_en }}_dao.get_list()

@staticmethod
async def get_all() -> Sequence[{{ table_name_class }}]:
async with async_db_session() as db:
Expand All @@ -27,9 +32,6 @@ class {{ table_name_class }}Service:
@staticmethod
async def create(*, obj: Create{{ schema_name }}Param) -> None:
async with async_db_session.begin() as db:
{{ table_name_en }} = await {{ table_name_en }}_dao.get_by_name(db, obj.name)
if {{ table_name_en }}:
raise errors.ForbiddenError(msg='{{ table_simple_name_zh }}已存在')
await {{ table_name_en }}_dao.create(db, obj)

@staticmethod
Expand Down

0 comments on commit 576ed77

Please sign in to comment.