はじめに
FastAPIでDDD(ドメイン駆動設計)を実践していると、コードの全体像を把握するのが難しくなってきます。
レイヤーをまたいだクラスの関係を誰かに説明したり、自分でレビューしたりするとき、「図があればわかりやすいのに」と感じることが増えてきました。
Pythonにはクラス図を自動生成できる pyreverse というツールがあります。pylintに同梱されているのでpip install pylintで使えます。ただし、実際に使ってみると2つの課題がありました。
- スコープの絞り込みができない — プロジェクト全体が図になってしまい、特定のAPI1本だけの図が欲しいのに実現が難しい
- 型アノテーションの解析が甘い —
Optional[Order]やList[Order]のような型をクラス間の関係として認識してくれない
この2つを解決するラッパースクリプトを作ったので紹介します。
作ったもの
関数名・ファイル・ディレクトリを起点に、関連するファイルだけを自動収集してクラス図を生成するスクリプトです。
pyreverseをそのまま使いつつ、不足している型アノテーションの関係検出をPythonのAST解析で補完しています。
# 関数名を指定するだけで対応するクラス図が生成される python generate_class_diagram.py search_data app/management .
プロジェクト構成
対象のプロジェクトはこのようなDDD構成です。
app/ ├── management/ ← APIのエンドポイント(起点) ├── usecase/ ← ユースケース層 ├── domain/ ← ドメイン層 └── infrastructure/ ← インフラ層
各APIは async def search_data(...) のような関数として実装されており、レイヤーをまたいで複数のクラスに依存しています。
課題1の解決:importを辿って関連ファイルだけを収集する
pyreverseはディレクトリを丸ごと渡すしかなく、特定のAPIに関係するファイルだけを絞り込む機能がありません。
そこで、Pythonの ast モジュールでimport文を再帰的に追跡するロジックを実装しました。
def collect_local_imports(file_path: Path, root: Path, visited: set[Path]) -> set[Path]: if file_path in visited or not file_path.exists(): return visited visited.add(file_path) tree = ast.parse(file_path.read_text(encoding="utf-8")) for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): if node.module and node.module.startswith("app"): _resolve_module(node.module, root, visited) return visited
from app.xxx import ... の形式のローカルimportだけを追跡し、外部ライブラリは無視します。起点ファイルから始まり、importの連鎖を再帰的に辿ることで、そのAPIが実際に使っているファイルだけを収集できます。
課題2の解決:型アノテーションの関係をAST解析で補完する
pyreverseが見落とす代表的なケースはこのようなものです。
class OrderUseCase: orders: List[Order] # → Orderとの関係が図に出ない current: Optional[Order] # → 同上 def execute(self, cmd: OrderCommand) -> OrderResult: # → 引数・戻り値も出ない ...
対策として、pyreverseでpumlを生成した後にAST解析で補完するpost-processingを実装しました。
型アノテーションの展開
Optional[X] や List[X] のような入れ子になった型から実際のクラス名を取り出す関数を実装しています。
def extract_type_names(annotation: ast.expr) -> list[str]: if isinstance(annotation, ast.Name): return [] if annotation.id in PRIMITIVE_TYPES else [annotation.id] if isinstance(annotation, ast.Subscript): outer = _get_name(annotation.value) if outer in CONTAINER_TYPES: # Optional, List, Set ... return extract_type_names(annotation.slice) if outer == "Union": return _extract_from_union(annotation.slice) # X | None (Python 3.10+) if isinstance(annotation, ast.BinOp) and isinstance(annotation.op, ast.BitOr): return [t for t in extract_type_names(annotation.left) + extract_type_names(annotation.right) if t != "None"] return []
対応している型パターンはこの通りです。
| 書き方 | 検出されるクラス |
|---|---|
Optional[Order] | Order |
List[Order] / list[Order] | Order |
Order | None(Python 3.10+) | Order |
Union[Order, None] | Order |
Set[Order] / Sequence[Order] | Order |
dict[str, Order] | Order(値の型) |
pumlへの補完
検出した関係のうち、pyreverseが既に出力しているものと重複しないものだけを @enduml の直前に追記します。
def patch_puml(puml_path: Path, relations: dict[str, set[str]], known_classes: set[str]): content = puml_path.read_text(encoding="utf-8") existing = set(re.findall(r'(\w+)\s+(?:--|-->|\.\.>)\s*(?:\S+\s+)*(\w+)', content)) new_lines = [] for cls, related_set in relations.items(): for related in related_set: if related in known_classes and (cls, related) not in existing: new_lines.append(f"{cls} --> {related} : uses") if new_lines: content = content.replace("@enduml", "\n".join(new_lines) + "\n@enduml") puml_path.write_text(content, encoding="utf-8")
おまけ:関数名でファイルを自動特定する
毎回ファイルパスを調べて入力するのが面倒なので、関数名から対象ファイルを自動検索する機能も追加しました。
def find_file_by_function(func_name: str, search_dir: Path) -> Path | None: for file_path in sorted(search_dir.rglob("*.py")): tree = ast.parse(file_path.read_text(encoding="utf-8")) for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): if node.name == func_name: return file_path return None
def / async def どちらも対応しており、デコレータの種類(@router.post か否か)は関係ありません。
使い方まとめ
# 関数名で検索(一番よく使う) python generate_class_diagram.py search_data app/management . # ファイル直接指定 python generate_class_diagram.py app/management/search_data.py . # ディレクトリ全体を一括生成 python generate_class_diagram.py app/management .
実行すると追跡したファイルの一覧と補完した関係数が表示されます。
関数名指定モード: search_data 検索ディレクトリ: app/management 発見: app/management/search_data.py ============================================================ 処理中: app/management/search_data.py 追跡ファイル数: 8件 app/usecase/search/search_usecase.py app/domain/search/search_model.py ... -> 出力: classes_search_data.puml -> 3件の関係を補完 (Optional/List等): SearchUseCase->Order, ...
できないこと(正直に)
pyreverseの上位互換を目指しましたが、以下の点は現状できていません。
- 関係の種類の区別 — 補完した関係はすべて
-->になります。コンポジション(*--)や集約(o--)は区別できません - 多重度 —
"1"/"*"/"0..1"などの多重度は補完部分には付きません - アノテーションなしの型推論 —
self.repo = OrderRepository()のような代入からは検出できません
とはいえ、DDDの実装ではほとんどの依存関係が型アノテーションで明示されているため、実用上は十分に機能しています。
まとめ
pyreverseは便利ですが、そのままでは大規模なDDD構成には使いにくい面があります。
今回作ったスクリプトで「特定のAPIに関係するクラスだけ」「Optional/Listの関係も含めて」図にできるようになり、設計の見直しやレビューの効率が上がりました。
コード全体はGitHubに公開予定です。同じような悩みを持っている方の参考になれば幸いです。
"""
pyreverse用クラス図生成スクリプト
importを再帰追跡して個別のクラス図を生成する
【関数名指定】特定の関数を持つファイルを自動検索して生成:
python generate_class_diagram.py <関数名> <検索ディレクトリ> [プロジェクトルート]
例: python generate_class_diagram.py search_data app/management .
【ファイル指定】直接ファイルを指定して生成:
python generate_class_diagram.py <ファイルパス> [プロジェクトルート]
例: python generate_class_diagram.py app/management/search_data.py .
【ディレクトリ指定】配下の全ファイルを一括生成:
python generate_class_diagram.py <ディレクトリ> [プロジェクトルート]
例: python generate_class_diagram.py app/management .
プロジェクトルートのデフォルトは . (カレントディレクトリ)
"""
import ast
import re
import subprocess
import sys
from pathlib import Path
# コンテナ型(中身の型を関係として抽出する)
CONTAINER_TYPES = {"Optional", "List", "list", "Set", "set", "Sequence", "Iterable", "FrozenSet", "frozenset", "Tuple", "tuple"}
# 無視するプリミティブ型
PRIMITIVE_TYPES = {"str", "int", "float", "bool", "bytes", "None", "Any", "dict", "Dict", "type"}
def extract_type_names(annotation: ast.expr) -> list[str]:
"""アノテーションノードから実際のクラス名を抽出(Optional/List等を展開)"""
if isinstance(annotation, ast.Name):
name = annotation.id
return [] if name in PRIMITIVE_TYPES else [name]
if isinstance(annotation, ast.Attribute):
return [annotation.attr]
if isinstance(annotation, ast.Subscript):
outer = _get_name(annotation.value)
if outer in CONTAINER_TYPES:
return extract_type_names(annotation.slice)
if outer in ("Union",):
return _extract_from_union(annotation.slice)
if outer in ("Dict", "dict"):
# dict[K, V] → Vの型のみ対象
if isinstance(annotation.slice, ast.Tuple) and len(annotation.slice.elts) >= 2:
return extract_type_names(annotation.slice.elts[1])
return []
# X | None (Python 3.10+)
if isinstance(annotation, ast.BinOp) and isinstance(annotation.op, ast.BitOr):
return [
t for t in extract_type_names(annotation.left) + extract_type_names(annotation.right)
if t != "None"
]
# Union[X, Y] のsliceがTupleの場合
if isinstance(annotation, ast.Tuple):
return _extract_from_union(annotation)
return []
def _get_name(node: ast.expr) -> str:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return node.attr
return ""
def _extract_from_union(node: ast.expr) -> list[str]:
results = []
if isinstance(node, ast.Tuple):
for elt in node.elts:
if isinstance(elt, ast.Constant) and elt.value is None:
continue
results.extend(extract_type_names(elt))
else:
results.extend(extract_type_names(node))
return results
def collect_class_relations(files: list[Path]) -> dict[str, set[str]]:
"""収集したファイル群からクラスのアノテーションを解析して関係を返す"""
relations: dict[str, set[str]] = {}
for file_path in files:
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
continue
for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef):
continue
cls_name = node.name
if cls_name not in relations:
relations[cls_name] = set()
# 継承元クラス
for base in node.bases:
for t in extract_type_names(base):
relations[cls_name].add(t)
for item in node.body:
# クラス変数のアノテーション: x: Optional[Foo] = ...
if isinstance(item, ast.AnnAssign):
for t in extract_type_names(item.annotation):
relations[cls_name].add(t)
# メソッドの戻り値アノテーションのみ対象
# 引数はDBセッション等の共通クラスが混入するため除外
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
if item.returns:
for t in extract_type_names(item.returns):
relations[cls_name].add(t)
return relations
def patch_puml(puml_path: Path, relations: dict[str, set[str]], known_classes: set[str]):
"""pyreverseが見落としたOptional/List等の関係をpumlに補完する"""
if not puml_path.exists():
return
content = puml_path.read_text(encoding="utf-8")
# 既存の関係を抽出(重複追加を防ぐ)
existing = set(re.findall(r'(\w+)\s+(?:--|-->|\.\.>)\s*(?:\S+\s+)*(\w+)', content))
new_lines = []
for cls, related_set in relations.items():
if cls not in known_classes:
continue
for related in related_set:
if related not in known_classes:
continue
if (cls, related) not in existing and (related, cls) not in existing:
new_lines.append(f"{cls} --> {related} : uses")
existing.add((cls, related))
if new_lines:
content = content.replace("@enduml", "\n".join(new_lines) + "\n@enduml")
puml_path.write_text(content, encoding="utf-8")
print(f" -> {len(new_lines)}件の関係を補完 (Optional/List等): {', '.join(l.split()[0] + '->' + l.split()[2] for l in new_lines)}")
def _clean_class_block(block: str) -> str:
"""クラスブロックからアクセス修飾子記号とdunderメソッドを除去する"""
lines = []
for line in block.splitlines():
stripped = line.lstrip()
# __xxx__ や __xxx のようなdunderは除外
if re.match(r'[+\-#~]?\s*__\w+', stripped):
continue
# アクセス修飾子記号を除去
line = re.sub(r'^(\s*)[+\-#~]', r'\1', line)
lines.append(line)
return "\n".join(lines)
def build_class_module_map(files: list[Path], root: Path) -> dict[str, str]:
"""クラス名 → モジュールパス のマッピングを構築"""
mapping = {}
for file_path in files:
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
rel = file_path.relative_to(root)
module = str(rel.parent).replace("\\", ".").replace("/", ".")
except (SyntaxError, UnicodeDecodeError, ValueError):
continue
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
mapping[node.name] = module
return mapping
def _brace_depth(line: str) -> int:
"""{field}/{method} 等のマーカーを除いてブロック深さの変化を計算"""
cleaned = re.sub(r'\{[a-z]+\}', '', line)
return cleaned.count("{") - cleaned.count("}")
def extract_class_blocks(content: str) -> dict[str, str]:
"""classes puml からクラス定義ブロックを抽出 {クラス名: ブロック文字列}
クォートあり/なし・フルモジュールパス形式・{field}/{method}マーカーに対応
例: class Foo / class "Foo" / class "app.domain.Foo"
"""
blocks = {}
lines = content.splitlines()
i = 0
while i < len(lines):
line = lines[i]
m = re.match(r'^\s*(?:abstract\s+)?class\s+"?([\w.]+)"?', line)
if m:
# フルパスの場合は最後のコンポーネントのみ使用 (app.domain.Foo → Foo)
cls_name = m.group(1).split(".")[-1]
block = [line]
depth = _brace_depth(line)
i += 1
while i < len(lines) and depth > 0:
block.append(lines[i])
depth += _brace_depth(lines[i])
i += 1
blocks[cls_name] = "\n".join(block)
else:
i += 1
return blocks
def extract_relationship_lines(content: str) -> list[str]:
"""puml から関係行を抽出"""
rel_pattern = re.compile(r'^\s*\S+.*?(?:-->|--\*|\*--|o--|--o|\.\.>|<\|--|--\|>).*\S')
return [line.strip() for line in content.splitlines() if rel_pattern.match(line)]
def _build_package_tree(module_classes: dict[str, list[str]]) -> dict:
"""モジュールパス(ドット区切り)をネストした木構造に変換"""
tree: dict = {}
for module, blocks in module_classes.items():
node = tree
for part in module.split("."):
node = node.setdefault(part, {})
node.setdefault("_classes", []).extend(blocks)
return tree
def _render_package_tree(tree: dict, indent: int = 0) -> list[str]:
"""木構造をネストした package ブロックとしてレンダリング"""
lines = []
pad = " " * indent
for key in sorted(k for k in tree if k != "_classes"):
subtree = tree[key]
lines.append(f'{pad}package "{key}" {{')
for block in subtree.get("_classes", []):
for line in block.splitlines():
lines.append(f"{pad} {line}")
child_keys = [k for k in subtree if k != "_classes"]
if child_keys:
lines.extend(_render_package_tree(
{k: subtree[k] for k in child_keys}, indent + 1
))
lines.append(f"{pad}}}")
lines.append("")
return lines
def _filter_class_block_methods(block: str, keep_methods: set[str]) -> str:
"""クラスブロックから呼び出されたメソッドのみ残す。属性行は常に残す。"""
lines = block.splitlines()
result = []
for line in lines:
stripped = line.strip()
# class行・ブレース・空行は常に残す
if not stripped or stripped in ("{", "}") or re.match(r'(?:abstract\s+)?class\b', stripped):
result.append(line)
continue
# メソッド行判定: word( で始まる
m = re.match(r'\s*(\w+)\s*\(', line)
if m:
if m.group(1) in keep_methods:
result.append(line)
# 呼ばれていないメソッドはスキップ
else:
# 属性行(フィールド定義)は残す
result.append(line)
return "\n".join(result)
def merge_diagrams(
classes_puml: Path,
class_module_map: dict[str, str],
diagram_name: str,
output_dir: Path,
entry_func: tuple[str, str, set[str]] | None = None,
extra_relations: dict[str, dict[str, set[str]]] | None = None,
):
"""クラス図をディレクトリ階層をネストしたパッケージ構造で再出力する。
entry_func = (func_name, entry_module, direct_uses): 関数ノードと矢印を追加
extra_relations = {src: {dep: {methods}}}: メソッド名ラベル付き矢印と使用メソッドフィルタに使用
"""
if not classes_puml.exists():
return
classes_content = classes_puml.read_text(encoding="utf-8")
raw_blocks = extract_class_blocks(classes_content)
known_classes = set(class_module_map.keys())
# 各クラスに到達したメソッド名を集約(フィルタ用)
used_methods_per_class: dict[str, set[str]] = {}
if extra_relations:
for dep_map in extra_relations.values():
for tgt_cls, methods in dep_map.items():
used_methods_per_class.setdefault(tgt_cls, set()).update(methods)
# reachable_classes でフィルタし、使用メソッドのみ残す
class_blocks = {}
for name, block in raw_blocks.items():
if name not in known_classes:
continue
cleaned = _clean_class_block(block)
if name in used_methods_per_class:
cleaned = _filter_class_block_methods(cleaned, used_methods_per_class[name])
class_blocks[name] = cleaned
# pyreverse由来の関係: 両端が既知クラスのものだけ残す
class_relations = [
line for line in extract_relationship_lines(classes_content)
if all(c in known_classes for c in re.findall(r'\b([A-Z]\w+)\b', line))
]
# extra_relations(BFS由来)をメソッド名ラベル付きで追加
if extra_relations:
existing_pairs: set[tuple[str, str]] = set()
for line in class_relations:
names = re.findall(r'\b([A-Z]\w+)\b', line)
if len(names) >= 2:
existing_pairs.add((names[0], names[-1]))
for src_cls, dep_map in extra_relations.items():
if src_cls not in known_classes:
continue
for dep_cls, methods in sorted(dep_map.items()):
if dep_cls not in known_classes or (src_cls, dep_cls) in existing_pairs:
continue
label = ", ".join(sorted(methods)) if methods else ""
arrow = f"{src_cls} --> {dep_cls}" + (f" : {label}" if label else "")
class_relations.append(arrow)
existing_pairs.add((src_cls, dep_cls))
# クラスをモジュールでグループ化
module_classes: dict[str, list[str]] = {}
for cls_name, block in class_blocks.items():
module = class_module_map[cls_name]
module_classes.setdefault(module, []).append(block)
# 関数ノードをパッケージ内に追加し、直接使用クラスへの矢印を生成
if entry_func:
ef_name, ef_module, ef_uses = entry_func
if ef_module:
func_block = f'() "{ef_name}"'
module_classes.setdefault(ef_module, []).insert(0, func_block)
for cls in ef_uses:
if cls in known_classes:
class_relations.append(f'"{ef_name}" --> {cls}')
# ネストしたパッケージツリーを構築・レンダリング
tree = _build_package_tree(module_classes)
out = [f"@startuml merged_{diagram_name}", ""]
out.extend(_render_package_tree(tree))
if class_relations:
out.append("' クラス間の関係")
out.extend(class_relations)
out.append("")
out.append("@enduml")
merged_path = output_dir / f"merged_{diagram_name}.puml"
merged_path.write_text("\n".join(out), encoding="utf-8")
print(f" -> 階層出力: {merged_path}")
def build_import_map(file_path: Path) -> dict[str, tuple[str, str]]:
"""ファイルのimport文から {ローカル名: (モジュール, 元の名前)} を構築"""
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
return {}
import_map: dict[str, tuple[str, str]] = {}
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
local = alias.asname or alias.name
import_map[local] = (node.module, alias.name)
elif isinstance(node, ast.Import):
for alias in node.names:
local = alias.asname or alias.name
import_map[local] = (alias.name, alias.name)
return import_map
def _file_defines_class(file_path: Path, class_name: str) -> bool:
"""ファイルが指定クラスを定義しているか確認"""
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
return False
return any(
isinstance(node, ast.ClassDef) and node.name == class_name
for node in ast.walk(tree)
)
def _get_classdef(file_path: Path, class_name: str) -> ast.ClassDef | None:
"""ファイルからクラス定義ASTノードを取得"""
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
return None
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
return node
return None
def find_class_file(module: str, class_name: str, root: Path) -> Path | None:
"""モジュールとクラス名からクラスを実際に定義しているファイルを特定する"""
if _is_excluded_module(module):
return None
mod_path = root / Path(module.replace(".", "/"))
direct_py = mod_path.with_suffix(".py")
if direct_py.exists():
return direct_py if _file_defines_class(direct_py, class_name) else None
if mod_path.is_dir():
for py_file in sorted(mod_path.glob("*.py")):
if py_file.name == "__init__.py":
continue
if _file_defines_class(py_file, class_name):
return py_file
init_py = mod_path / "__init__.py"
if init_py.exists() and _file_defines_class(init_py, class_name):
return init_py
return None
def class_dependencies(classdef: ast.ClassDef) -> dict[str, set[str]]:
"""クラス定義から {依存クラス名: {呼び出しメソッド名}} を返す。
メソッド名が空集合 = 型参照のみ(アノテーション・継承)。
self.field.method() パターンでメソッド名を記録する。
"""
deps: dict[str, set[str]] = {}
def _add(name: str, method: str = "") -> None:
if name and name not in PRIMITIVE_TYPES:
deps.setdefault(name, set())
if method:
deps[name].add(method)
# 継承元
for base in classdef.bases:
for t in extract_type_names(base):
_add(t)
# フィールドの型マップ: self.field.method() → 型解決に使用
# 属性アノテーション と @property 戻り値の両方を登録する
field_type_map: dict[str, str] = {}
for item in classdef.body:
if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
types = extract_type_names(item.annotation)
if types:
field_type_map[item.target.id] = types[0]
_add(types[0])
elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.returns:
is_property = any(
(isinstance(d, ast.Name) and d.id == "property") or
(isinstance(d, ast.Attribute) and d.attr == "property")
for d in item.decorator_list
)
if is_property:
types = extract_type_names(item.returns)
if types:
field_type_map[item.name] = types[0]
_add(types[0])
for item in classdef.body:
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
# 戻り値アノテーション
if item.returns:
for t in extract_type_names(item.returns):
_add(t)
# メソッドbody
for stmt in item.body:
for node in ast.walk(stmt):
if not isinstance(node, ast.Call):
continue
func = node.func
# ClassX() 直接インスタンス化
if isinstance(func, ast.Name):
_add(func.id)
# self.field.method() パターン: フィールド経由のメソッド呼び出し
elif (isinstance(func, ast.Attribute) and
isinstance(func.value, ast.Attribute) and
isinstance(func.value.value, ast.Name) and
func.value.value.id == "self"):
field_name = func.value.attr
method_name = func.attr
if field_name in field_type_map:
_add(field_type_map[field_name], method_name)
return deps
def resolve_symbol(
name: str,
import_map: dict[str, tuple[str, str]],
context_file: Path,
root: Path,
) -> tuple[str, Path] | None:
"""クラス名を (正式クラス名, 定義ファイル) に解決する"""
if name in import_map:
module, original_name = import_map[name]
if module.startswith("app") and not _is_excluded_module(module):
cls_file = find_class_file(module, original_name, root)
if cls_file:
return original_name, cls_file
if _file_defines_class(context_file, name):
return name, context_file
return None
def trace_from_function(
entry_file: Path, func_name: str, root: Path
) -> tuple[set[Path], dict[str, str], dict[str, dict[str, set[str]]], set[str]]:
"""関数名を起点にクラスレベルのBFSで到達クラスを収集する(ファイル単位の再帰は行わない)。
戻り値: (collected_files, class_module_map, relations, direct_uses)
relations: {src_class: {dep_class: {method_names}}}
"""
entry_import_map = build_import_map(entry_file)
try:
tree = ast.parse(entry_file.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
return {entry_file}, {}, {}, set()
target_func = None
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name:
target_func = node
break
if target_func is None:
return {entry_file}, {}, {}, set()
# 関数body内の直接呼び出し ClassX() のみ収集(module.X() は除外)
used_names: set[str] = set()
for stmt in target_func.body:
for node in ast.walk(stmt):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
used_names.add(node.func.id)
# 種: 起点関数が直接参照するクラス
worklist: list[tuple[str, Path]] = []
direct_uses: set[str] = set()
for local_name in used_names:
resolved = resolve_symbol(local_name, entry_import_map, entry_file, root)
if resolved:
cls_name, cls_file = resolved
worklist.append((cls_name, cls_file))
direct_uses.add(cls_name)
# BFS: クラスを辿るたびにそのクラス定義の依存だけを展開する
visited: set[tuple[str, Path]] = set()
collected_files: set[Path] = {entry_file}
class_module_map: dict[str, str] = {}
relations: dict[str, dict[str, set[str]]] = {}
while worklist:
cls_name, cls_file = worklist.pop(0)
key = (cls_name, cls_file)
if key in visited:
continue
visited.add(key)
collected_files.add(cls_file)
try:
rel = cls_file.relative_to(root)
mod_str = str(rel.parent).replace("\\", ".").replace("/", ".")
class_module_map[cls_name] = mod_str if mod_str != "." else ""
except ValueError:
class_module_map[cls_name] = ""
classdef = _get_classdef(cls_file, cls_name)
if classdef is None:
continue
cls_import_map = build_import_map(cls_file)
deps_with_methods = class_dependencies(classdef)
relations[cls_name] = {}
for dep_local, methods in deps_with_methods.items():
resolved = resolve_symbol(dep_local, cls_import_map, cls_file, root)
if resolved:
dep_cls_name, dep_file = resolved
# 同一依存先が複数経路で現れる場合はメソッドセットをマージ
relations[cls_name].setdefault(dep_cls_name, set()).update(methods)
if (dep_cls_name, dep_file) not in visited:
worklist.append((dep_cls_name, dep_file))
return collected_files, class_module_map, relations, direct_uses
def find_file_by_function(func_name: str, search_dir: Path) -> Path | None:
"""指定した関数名(def/async def)を含むファイルをディレクトリから検索する"""
for file_path in sorted(search_dir.rglob("*.py")):
if file_path.name == "__init__.py":
continue
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError):
continue
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name:
return file_path
return None
def _is_excluded_module(module_name: str) -> bool:
"""_ または __ で始まるディレクトリ(共通・内部パッケージ)を除外する"""
return any(part.startswith("_") for part in module_name.split("."))
def collect_local_imports(file_path: Path, root: Path, visited: set[Path]) -> set[Path]:
if file_path in visited or not file_path.exists():
return visited
visited.add(file_path)
try:
tree = ast.parse(file_path.read_text(encoding="utf-8"))
except (SyntaxError, UnicodeDecodeError) as e:
print(f" [スキップ] {file_path.name}: {e}")
return visited
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom):
if node.module and node.module.startswith("app") and not _is_excluded_module(node.module):
imported_names = [alias.name for alias in node.names]
_resolve_module(node.module, imported_names, root, visited)
elif isinstance(node, ast.Import):
for alias in node.names:
if alias.name.startswith("app") and not _is_excluded_module(alias.name):
_resolve_module(alias.name, [], root, visited)
return visited
def _resolve_module(module_name: str, imported_names: list[str], root: Path, visited: set[Path]):
"""import された名前に対応するファイルを直接特定して追跡する。
__init__.py 経由の一括取得を避け、関係のないクラスの混入を防ぐ。
"""
mod_path = root / Path(module_name.replace(".", "/"))
# from app.domain.order import Order → app/domain/order.py が存在する場合
direct_py = mod_path.with_suffix(".py")
if direct_py.exists():
collect_local_imports(direct_py, root, visited)
return
# パッケージ(ディレクトリ)の場合、import した名前のファイルを直接探す
if mod_path.is_dir():
resolved = False
for name in imported_names:
if name == "*":
continue
# app/domain/order/order.py のようなサブモジュールファイル
submodule_py = mod_path / f"{name}.py"
if submodule_py.exists():
collect_local_imports(submodule_py, root, visited)
resolved = True
continue
# サブパッケージ(ディレクトリ)
subpkg_init = mod_path / name / "__init__.py"
if subpkg_init.exists():
collect_local_imports(subpkg_init, root, visited)
resolved = True
# 個別ファイルが見つからなかった場合のみ __init__.py を追跡
if not resolved:
init_py = mod_path / "__init__.py"
if init_py.exists():
collect_local_imports(init_py, root, visited)
def generate_diagram(entry: Path, root: Path, output_dir: Path, func_name: str | None = None):
print(f"\n{'='*60}")
print(f"処理中: {entry.relative_to(root)}" + (f" (関数: {func_name})" if func_name else ""))
if func_name:
collected_files, class_module_map, relations, direct_uses = trace_from_function(entry, func_name, root)
py_files_paths = sorted(f for f in collected_files if f.suffix == ".py")
rel = entry.relative_to(root)
ef_module_raw = str(rel.parent).replace("\\", ".").replace("/", ".")
entry_module = ef_module_raw if ef_module_raw and ef_module_raw != "." else None
else:
files = collect_local_imports(entry, root, set())
py_files_paths = sorted(f for f in files if f.suffix == ".py")
class_module_map = build_class_module_map(py_files_paths, root)
relations = collect_class_relations(py_files_paths)
direct_uses = set()
entry_module = None
py_files = [str(f) for f in py_files_paths]
print(f" 追跡ファイル数: {len(py_files)}件")
for f in py_files_paths:
try:
print(f" {f.relative_to(root)}")
except ValueError:
print(f" {f}")
diagram_name = entry.stem
result = subprocess.run(
["pyreverse", "-o", "puml", "-p", diagram_name] + py_files,
capture_output=True,
text=True,
cwd=str(output_dir),
)
if result.returncode != 0:
print(f" [エラー] pyreverse失敗:")
print(f" {result.stderr.strip()}")
return
classes_puml = output_dir / f"classes_{diagram_name}.puml"
print(f" -> 出力: {classes_puml}")
ef = (func_name, entry_module, direct_uses) if func_name and entry_module else None
if func_name:
# 関数名モード: BFS由来のrelationsが既に正確なのでpatch_pumlはスキップ
# extra_relationsとしてmerge_diagramsに渡し、reachable_classesでフィルタして出力
merge_diagrams(classes_puml, class_module_map, diagram_name, output_dir,
entry_func=ef, extra_relations=relations)
else:
# ファイル/ディレクトリモード: 従来通りpatch_pumlで補完してからマージ
known_classes = set(class_module_map.keys())
patch_puml(classes_puml, relations, known_classes)
merge_diagrams(classes_puml, class_module_map, diagram_name, output_dir, entry_func=ef)
def main():
if len(sys.argv) < 2:
print(__doc__)
sys.exit(1)
output_dir = Path(".").resolve()
first_arg = sys.argv[1]
# 【関数名指定モード】パスとして存在しない かつ 引数が2つ以上ある場合
# 例: python generate_class_diagram.py search_data app/management .
first_as_path = Path(first_arg)
if not first_as_path.exists() and len(sys.argv) >= 3:
func_name = first_arg
search_dir_arg = sys.argv[2]
project_root = sys.argv[3] if len(sys.argv) > 3 else "."
root = Path(project_root).resolve()
search_dir = (root / search_dir_arg).resolve()
if not search_dir.exists():
print(f"エラー: 検索ディレクトリが見つかりません -> {search_dir}")
sys.exit(1)
print(f"関数名指定モード: {func_name}")
print(f"検索ディレクトリ: {search_dir.relative_to(root)}")
entry = find_file_by_function(func_name, search_dir)
if entry is None:
print(f"エラー: def {func_name} が見つかりませんでした")
sys.exit(1)
print(f"発見: {entry.relative_to(root)}")
generate_diagram(entry, root, output_dir, func_name=func_name)
print("\n完了")
return
# 【ファイル / ディレクトリ指定モード】
target = first_arg
project_root = sys.argv[2] if len(sys.argv) > 2 else "."
root = Path(project_root).resolve()
target_path = (root / target).resolve()
if not target_path.exists():
print(f"エラー: 見つかりません -> {target_path}")
print(__doc__)
sys.exit(1)
print(f"プロジェクトルート: {root}")
print(f"出力先: {output_dir}")
# ファイル指定
if target_path.is_file():
print(f"\nファイル指定モード: {target_path.relative_to(root)}")
generate_diagram(target_path, root, output_dir)
print("\n完了")
return
# ディレクトリ指定: __init__.py を除く全 .py ファイルを対象
print(f"\nディレクトリ指定モード: {target_path.relative_to(root)}")
entry_files = [
f for f in sorted(target_path.rglob("*.py"))
if f.name != "__init__.py"
]
if not entry_files:
print("対象ファイルが見つかりませんでした。")
sys.exit(1)
print(f"{len(entry_files)}件のファイルを対象:")
for f in entry_files:
print(f" {f.relative_to(root)}")
for entry in entry_files:
generate_diagram(entry, root, output_dir)
print(f"\n{'='*60}")
print(f"完了: {len(entry_files)}件のクラス図を生成しました")
if __name__ == "__main__":
main()
