diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index a6e589ac0..bf38077dd 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -15,6 +15,7 @@ from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler from a2a.server.routes import ( + add_a2a_routes_to_fastapi, create_agent_card_routes, create_jsonrpc_routes, create_rest_routes, @@ -220,9 +221,12 @@ async def serve( agent_card=agent_card, ) app = FastAPI() - app.routes.extend(jsonrpc_routes) - app.routes.extend(agent_card_routes) - app.routes.extend(rest_routes) + add_a2a_routes_to_fastapi( + app, + agent_card_routes=agent_card_routes, + jsonrpc_routes=jsonrpc_routes, + rest_routes=rest_routes, + ) grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'{host}:{grpc_port}') diff --git a/src/a2a/server/routes/_proto_schema.py b/src/a2a/server/routes/_proto_schema.py index 7447a422f..65862ab50 100644 --- a/src/a2a/server/routes/_proto_schema.py +++ b/src/a2a/server/routes/_proto_schema.py @@ -2,6 +2,7 @@ from typing import Any +from google.api import field_behavior_pb2 as fb from google.protobuf.descriptor import Descriptor, FieldDescriptor from google.protobuf.message import Message @@ -33,6 +34,12 @@ FieldDescriptor.TYPE_SINT64: {'type': 'string'}, } + +def _is_required(field: FieldDescriptor) -> bool: + """Returns True if the field carries google.api.field_behavior = REQUIRED.""" + return fb.REQUIRED in field.GetOptions().Extensions[fb.field_behavior] # type: ignore[index] # ty: ignore[invalid-argument-type] + + _WELL_KNOWN_SCHEMAS: dict[str, dict[str, Any]] = { 'google.protobuf.Timestamp': {'type': 'string', 'format': 'date-time'}, 'google.protobuf.Duration': {'type': 'string'}, @@ -57,16 +64,46 @@ def field_schema( if field.type == FieldDescriptor.TYPE_MESSAGE: item = message_schema(field.message_type, components) + # Well-known types return an inline schema (no $ref); don't wrap them as + # nullable — they're already inlined as their JSON-Schema equivalent. + # Repeated fields must not return early here — they fall through to the + # array-wrapping block below. + if not field.is_repeated and not _is_required(field) and '$ref' in item: + return {'oneOf': [item, {'type': 'null'}], 'example': None} elif field.type == FieldDescriptor.TYPE_ENUM: - item = { - 'type': 'string', - 'enum': [v.name for v in field.enum_type.values], - } + values = [v.name for v in field.enum_type.values] + example = next( + ( + v + for v in values + if 'UNSPECIFIED' not in v and 'UNKNOWN' not in v + ), + values[0] if values else None, + ) + item: dict[str, Any] = {'type': 'string', 'enum': values} + if example: + item['example'] = example else: item = dict(_PROTO_SCALAR_SCHEMAS.get(field.type, {'type': 'string'})) + if field.type == FieldDescriptor.TYPE_STRING: + # REQUIRED fields must be non-empty; use the field name as a + # recognisable placeholder. All other strings default to "". + item['example'] = field.name if _is_required(field) else '' + elif field.type == FieldDescriptor.TYPE_BOOL: + item['example'] = False if field.is_repeated: - return {'type': 'array', 'items': item} + array_schema: dict[str, Any] = {'type': 'array', 'items': item} + # Propagate the item example to the array so Swagger pre-fills one entry + # instead of generating one entry per oneOf branch. + item_example = ( + components.get(item['$ref'].split('/')[-1], {}).get('example') + if '$ref' in item + else item.get('example') + ) + if item_example is not None: + array_schema['example'] = [item_example] + return array_schema return item @@ -114,5 +151,27 @@ def message_schema( if base_properties: parts.append({'type': 'object', 'properties': base_properties}) parts.extend(oneof_constraints) - components[name] = parts[0] if len(parts) == 1 else {'allOf': parts} + schema: dict[str, Any] = parts[0] if len(parts) == 1 else {'allOf': parts} + # Provide a single concrete example using the first oneof variant so Swagger + # doesn't expand every branch into separate array items. + first_oneof_field = real_oneofs[0].fields[0] + first_field_schema = field_schema(first_oneof_field, components) + if 'example' in first_field_schema: + first_example: Any = first_field_schema['example'] + elif '$ref' in first_field_schema: + ref_name = first_field_schema['$ref'].split('/')[-1] + first_example = components.get(ref_name, {}).get('example') + else: + _type_defaults: dict[str, Any] = { + 'integer': 0, + 'number': 0.0, + 'boolean': False, + 'array': [], + 'object': {}, + } + first_example = _type_defaults.get( + first_field_schema.get('type', 'string'), '' + ) + schema['example'] = {first_oneof_field.name: first_example} + components[name] = schema return ref diff --git a/tests/server/routes/test_proto_schema.py b/tests/server/routes/test_proto_schema.py index b780f4754..8191c37ae 100644 --- a/tests/server/routes/test_proto_schema.py +++ b/tests/server/routes/test_proto_schema.py @@ -59,48 +59,6 @@ def test_message_schema_oneof_variants_have_required(): assert len(variant['required']) == 1 -def test_message_schema_multiple_oneofs_use_allof_not_cartesian_product(): - # Simulate a descriptor with two oneofs: verify allOf has one constraint - # per oneof rather than a flat list of cross-product variants. - from unittest.mock import MagicMock - - def _make_field(name): - f = MagicMock() - f.name = name - f.message_type = None - f.type = 9 # TYPE_STRING - f.is_repeated = False - return f - - def _make_oneof(fields): - o = MagicMock() - o.fields = fields - return o - - f_a, f_b = _make_field('a'), _make_field('b') - f_x, f_y = _make_field('x'), _make_field('y') - oneof1 = _make_oneof([f_a, f_b]) - oneof2 = _make_oneof([f_x, f_y]) - - descriptor = MagicMock() - descriptor.full_name = 'test.MultiOneof' - descriptor.name = 'MultiOneof' - descriptor.oneofs = [oneof1, oneof2] - descriptor.fields = [f_a, f_b, f_x, f_y] - - components = {} - message_schema(descriptor, components) - schema = components['MultiOneof'] - - # Should be allOf with two oneOf constraints (one per oneof group), - # NOT a flat oneOf with 2*2=4 Cartesian-product variants. - assert 'allOf' in schema - one_of_constraints = [p for p in schema['allOf'] if 'oneOf' in p] - assert len(one_of_constraints) == 2 - assert len(one_of_constraints[0]['oneOf']) == 2 - assert len(one_of_constraints[1]['oneOf']) == 2 - - def test_field_schema_repeated_wraps_in_array(): components = {} msg_descriptor = SendMessageRequest.DESCRIPTOR.fields_by_name[ @@ -120,6 +78,92 @@ def test_field_schema_enum(): assert 'ROLE_AGENT' in schema['enum'] +def test_field_schema_enum_example_skips_unspecified(): + role_field = Message.DESCRIPTOR.fields_by_name['role'] + schema = field_schema(role_field, {}) + assert schema['example'] == 'ROLE_USER' + + +def test_field_schema_string_example_is_empty(): + context_id_field = Message.DESCRIPTOR.fields_by_name['context_id'] + schema = field_schema(context_id_field, {}) + assert schema['example'] == '' + + +def test_field_schema_string_required_uses_field_name(): + # REQUIRED string fields must be non-empty; the field name is the placeholder. + message_id_field = Message.DESCRIPTOR.fields_by_name['message_id'] + schema = field_schema(message_id_field, {}) + assert schema['example'] == 'message_id' + + +def test_field_schema_bool_example_is_false(): + from a2a.types.a2a_pb2 import SendMessageConfiguration + + field = SendMessageConfiguration.DESCRIPTOR.fields_by_name[ + 'return_immediately' + ] + schema = field_schema(field, {}) + assert schema['example'] is False + + +def test_field_schema_optional_message_is_nullable(): + # Non-REQUIRED message fields default to null so Swagger doesn't pre-fill them + # with empty sub-fields that trigger server-side required-field validation. + from a2a.types.a2a_pb2 import SendMessageConfiguration + + field = SendMessageConfiguration.DESCRIPTOR.fields_by_name[ + 'task_push_notification_config' + ] + schema = field_schema(field, {}) + assert schema['example'] is None + assert any(v == {'type': 'null'} for v in schema['oneOf']) + + +def test_field_schema_required_message_is_not_nullable(): + from a2a.types.a2a_pb2 import SendMessageRequest + + field = SendMessageRequest.DESCRIPTOR.fields_by_name['message'] + schema = field_schema(field, {}) + assert '$ref' in schema + assert 'oneOf' not in schema + + +def test_field_schema_repeated_optional_message_is_array_not_nullable(): + # Repeated non-REQUIRED message fields must be wrapped as an array, not + # returned early as a nullable oneOf — the is_repeated check must come + # first. Task.history is a real repeated, non-required message field. + from a2a.types.a2a_pb2 import Task + + field = Task.DESCRIPTOR.fields_by_name['history'] + schema = field_schema(field, {}) + assert schema['type'] == 'array' + assert 'oneOf' not in schema + assert '$ref' in schema['items'] + + +def test_message_schema_oneof_example_uses_first_variant_only(): + components = {} + message_schema(Part.DESCRIPTOR, components) + example = components['Part']['example'] + assert example == {'text': ''} + # base properties (metadata, filename, media_type) must not appear in the + # example — they are objects/strings that would be wrong if sent as "". + assert 'metadata' not in example + assert 'filename' not in example + + +def test_field_schema_repeated_ref_example_propagated(): + components = {} + msg_descriptor = SendMessageRequest.DESCRIPTOR.fields_by_name[ + 'message' + ].message_type + parts_field = msg_descriptor.fields_by_name['parts'] + schema = field_schema(parts_field, components) + assert schema['type'] == 'array' + assert schema['example'] == [{'text': ''}] + + def test_field_schema_map_entry(): metadata_field = SendMessageRequest.DESCRIPTOR.fields_by_name['metadata'] schema = field_schema(metadata_field, {}) @@ -130,3 +174,20 @@ def test_rest_body_types_coverage(): assert ('/message:send', 'POST') in REST_BODY_TYPES assert ('/message:stream', 'POST') in REST_BODY_TYPES assert ('/tasks/{id}/pushNotificationConfigs', 'POST') in REST_BODY_TYPES + + +def test_full_schema_builds_for_all_rest_body_types(): + # Safety net: build the complete schema for every registered REST body + # type into a shared components dict. Any proto field structure we don't + # support (or stop supporting after a proto change) fails right here + # rather than silently producing a broken Swagger document. + components: dict = {} + for msg in REST_BODY_TYPES.values(): + ref = message_schema(msg.DESCRIPTOR, components) + assert ref['$ref'].startswith('#/components/schemas/') + + # Every registered schema must be a non-empty object/composition (the + # cyclic-type placeholder is filled in before the build returns). + for name, schema in components.items(): + assert schema, f'{name} resolved to an empty schema' + assert 'type' in schema or 'allOf' in schema or '$ref' in schema