Skip to content

Commit be12eed

Browse files
authored
feat(firebaseai): add bidi transcript (#17700)
* Add transcription into server content * Add transcription config * transcription working * add test for transcription * apply gemini suggestion * fix the transcript index in bidi_page * more review comment
1 parent f9ca819 commit be12eed

File tree

6 files changed

+178
-5
lines changed

6 files changed

+178
-5
lines changed

packages/firebase_ai/firebase_ai/example/lib/pages/bidi_page.dart

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class _BidiPageState extends State<BidiPage> {
5757
StreamController<bool> _stopController = StreamController<bool>();
5858
final AudioOutput _audioOutput = AudioOutput();
5959
final AudioInput _audioInput = AudioInput();
60+
int? _inputTranscriptionMessageIndex;
61+
int? _outputTranscriptionMessageIndex;
6062

6163
@override
6264
void initState() {
@@ -67,6 +69,8 @@ class _BidiPageState extends State<BidiPage> {
6769
responseModalities: [
6870
ResponseModalities.audio,
6971
],
72+
inputAudioTranscription: AudioTranscriptionConfig(),
73+
outputAudioTranscription: AudioTranscriptionConfig(),
7074
);
7175

7276
// ignore: deprecated_member_use
@@ -353,6 +357,49 @@ class _BidiPageState extends State<BidiPage> {
353357
if (message.modelTurn != null) {
354358
await _handleLiveServerContent(message);
355359
}
360+
361+
// Folds one streamed transcription chunk into `_messages` and returns the
// index the caller should pass back in for the next chunk.
//
// - `transcription`: the chunk to apply; ignored when it or its text is null.
// - `messageIndex`: index of the message currently being built, or null to
//   start a new message.
// - `prefix`: label prepended only when a new message is created.
// - `fromUser`: forwarded to the new `MessageData`.
//
// Returns null once a chunk with text reports `finished == true`, the index
// of the message being extended otherwise, and `messageIndex` unchanged when
// the chunk carries no text.
int? _handleTranscription(
362+
Transcription? transcription,
363+
int? messageIndex,
364+
String prefix,
365+
bool fromUser,
366+
) {
367+
int? currentIndex = messageIndex;
368+
// NOTE(review): the `finished` check below lives inside this branch, so a
// chunk with `finished: true` but no text would leave the index set —
// confirm the server never sends that shape.
if (transcription?.text != null) {
369+
if (currentIndex != null) {
370+
// Append the new chunk to the text of the message already in the list.
_messages[currentIndex] = _messages[currentIndex].copyWith(
371+
text: '${_messages[currentIndex].text}${transcription!.text!}',
372+
);
373+
} else {
374+
// First chunk of a new transcription: create the message with its label.
_messages.add(
375+
MessageData(
376+
text: '$prefix${transcription!.text!}',
377+
fromUser: fromUser,
378+
),
379+
);
380+
currentIndex = _messages.length - 1;
381+
}
382+
// A finished transcription clears the index so the next chunk starts a
// new message.
if (transcription.finished ?? false) {
383+
currentIndex = null;
384+
}
385+
setState(_scrollDown);
386+
}
387+
return currentIndex;
388+
}
389+
390+
_inputTranscriptionMessageIndex = _handleTranscription(
391+
message.inputTranscription,
392+
_inputTranscriptionMessageIndex,
393+
'Input transcription: ',
394+
true,
395+
);
396+
_outputTranscriptionMessageIndex = _handleTranscription(
397+
message.outputTranscription,
398+
_outputTranscriptionMessageIndex,
399+
'Output transcription: ',
400+
false,
401+
);
402+
356403
if (message.interrupted != null && message.interrupted!) {
357404
developer.log('Interrupted: $response');
358405
}

packages/firebase_ai/firebase_ai/example/lib/widgets/message_widget.dart

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@ class MessageData {
2222
this.fromUser,
2323
this.isThought = false,
2424
});
25+
26+
/// Returns a new [MessageData] with the given fields replaced.
///
/// Any argument that is omitted or null keeps the value from this instance,
/// so this method cannot be used to reset a field to null.
MessageData copyWith({
27+
Uint8List? imageBytes,
28+
String? text,
29+
bool? fromUser,
30+
bool? isThought,
31+
}) {
32+
return MessageData(
33+
imageBytes: imageBytes ?? this.imageBytes,
34+
text: text ?? this.text,
35+
fromUser: fromUser ?? this.fromUser,
36+
isThought: isThought ?? this.isThought,
37+
);
38+
}
39+
2540
final Uint8List? imageBytes;
2641
final String? text;
2742
final bool? fromUser;

packages/firebase_ai/firebase_ai/lib/firebase_ai.dart

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,13 @@ export 'src/live_api.dart'
9393
show
9494
LiveGenerationConfig,
9595
SpeechConfig,
96+
AudioTranscriptionConfig,
9697
LiveServerMessage,
9798
LiveServerContent,
9899
LiveServerToolCall,
99100
LiveServerToolCallCancellation,
100-
LiveServerResponse;
101+
LiveServerResponse,
102+
Transcription;
101103
export 'src/live_session.dart' show LiveSession;
102104
export 'src/schema.dart' show Schema, SchemaType;
103105
export 'src/tool.dart'

packages/firebase_ai/firebase_ai/lib/src/live_api.dart

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,19 @@ class SpeechConfig {
7171
};
7272
}
7373

74+
/// The audio transcription configuration.
///
/// The class declares no fields, and [toJson] always returns an empty map.
75+
class AudioTranscriptionConfig {
76+
// ignore: public_member_api_docs
77+
Map<String, Object?> toJson() => {};
78+
}
79+
7480
/// Configures live generation settings.
7581
final class LiveGenerationConfig extends BaseGenerationConfig {
7682
// ignore: public_member_api_docs
7783
LiveGenerationConfig({
7884
this.speechConfig,
85+
this.inputAudioTranscription,
86+
this.outputAudioTranscription,
7987
super.responseModalities,
8088
super.maxOutputTokens,
8189
super.temperature,
@@ -88,6 +96,13 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
8896
/// The speech configuration.
8997
final SpeechConfig? speechConfig;
9098

99+
/// The transcription of the input aligns with the input audio language.
100+
final AudioTranscriptionConfig? inputAudioTranscription;
101+
102+
/// The transcription of the output aligns with the language code specified for
103+
/// the output audio.
104+
final AudioTranscriptionConfig? outputAudioTranscription;
105+
91106
@override
92107
Map<String, Object?> toJson() => {
93108
...super.toJson(),
@@ -109,14 +124,33 @@ sealed class LiveServerMessage {}
109124
/// with the live server has finished successfully.
110125
class LiveServerSetupComplete implements LiveServerMessage {}
111126

127+
/// Audio transcription message.
///
/// Both [text] and [finished] are optional and may be null.
128+
class Transcription {
129+
// ignore: public_member_api_docs
130+
const Transcription({this.text, this.finished});
131+
132+
/// Transcription text, or null when none was provided.
133+
final String? text;
134+
135+
/// Whether this is the end of the transcription.
///
/// Null when the flag was not provided.
136+
final bool? finished;
137+
}
138+
112139
/// Content generated by the model in a live stream.
113140
class LiveServerContent implements LiveServerMessage {
114141
/// Creates a [LiveServerContent] instance.
115142
///
116143
/// [modelTurn] (optional): The content generated by the model.
117144
/// [turnComplete] (optional): Indicates if the turn is complete.
118145
/// [interrupted] (optional): Indicates if the generation was interrupted.
119-
LiveServerContent({this.modelTurn, this.turnComplete, this.interrupted});
146+
/// [inputTranscription] (optional): The input transcription.
147+
/// [outputTranscription] (optional): The output transcription.
148+
LiveServerContent(
149+
{this.modelTurn,
150+
this.turnComplete,
151+
this.interrupted,
152+
this.inputTranscription,
153+
this.outputTranscription});
120154

121155
// TODO(cynthia): Add accessor for media content
122156
/// The content generated by the model.
@@ -129,6 +163,18 @@ class LiveServerContent implements LiveServerMessage {
129163
/// Whether generation was interrupted. If true, indicates that a
130164
/// client message has interrupted current model
131165
final bool? interrupted;
166+
167+
/// The input transcription.
168+
///
169+
/// The transcription is independent of the model turn which means it doesn't
170+
/// imply any ordering between transcription and model turn.
171+
final Transcription? inputTranscription;
172+
173+
/// The output transcription.
174+
///
175+
/// The transcription is independent of the model turn which means it doesn't
176+
/// imply any ordering between transcription and model turn.
177+
final Transcription? outputTranscription;
132178
}
133179

134180
/// A tool call in a live stream.
@@ -344,7 +390,26 @@ LiveServerMessage _parseServerMessage(Object jsonObject) {
344390
if (serverContentJson.containsKey('turnComplete')) {
345391
turnComplete = serverContentJson['turnComplete'] as bool;
346392
}
347-
return LiveServerContent(modelTurn: modelTurn, turnComplete: turnComplete);
393+
final interrupted = serverContentJson['interrupted'] as bool?;
394+
// Builds a [Transcription] from the entry stored under `key` in
// `serverContentJson`, or returns null when the key is absent.
//
// The `text` and `finished` entries are both optional and are passed through
// as nullable values.
Transcription? _parseTranscription(String key) {
395+
if (serverContentJson.containsKey(key)) {
396+
// NOTE(review): the cast is to a non-nullable map, so a key that is
// present with a null or non-map value would throw here — confirm the
// server never sends that.
final transcriptionJson =
397+
serverContentJson[key] as Map<String, dynamic>;
398+
return Transcription(
399+
text: transcriptionJson['text'] as String?,
400+
finished: transcriptionJson['finished'] as bool?,
401+
);
402+
}
403+
return null;
404+
}
405+
406+
return LiveServerContent(
407+
modelTurn: modelTurn,
408+
turnComplete: turnComplete,
409+
interrupted: interrupted,
410+
inputTranscription: _parseTranscription('inputTranscription'),
411+
outputTranscription: _parseTranscription('outputTranscription'),
412+
);
348413
} else if (json.containsKey('toolCall')) {
349414
final toolContentJson = json['toolCall'] as Map<String, dynamic>;
350415
List<FunctionCall> functionCalls = [];

packages/firebase_ai/firebase_ai/lib/src/live_model.dart

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,18 @@ final class LiveGenerativeModel extends BaseModel {
101101
final setupJson = {
102102
'setup': {
103103
'model': modelString,
104-
if (_liveGenerationConfig != null)
105-
'generation_config': _liveGenerationConfig.toJson(),
106104
if (_systemInstruction != null)
107105
'system_instruction': _systemInstruction.toJson(),
108106
if (_tools != null) 'tools': _tools.map((t) => t.toJson()).toList(),
107+
if (_liveGenerationConfig != null) ...{
108+
'generation_config': _liveGenerationConfig.toJson(),
109+
if (_liveGenerationConfig.inputAudioTranscription != null)
110+
'input_audio_transcription':
111+
_liveGenerationConfig.inputAudioTranscription!.toJson(),
112+
if (_liveGenerationConfig.outputAudioTranscription != null)
113+
'output_audio_transcription':
114+
_liveGenerationConfig.outputAudioTranscription!.toJson(),
115+
},
109116
}
110117
};
111118

packages/firebase_ai/firebase_ai/test/live_test.dart

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,5 +240,42 @@ void main() {
240240
expect(() => parseServerResponse(jsonObject),
241241
throwsA(isA<FirebaseAISdkException>()));
242242
});
243+
244+
test(
245+
'LiveGenerationConfig with transcriptions toJson() returns correct JSON',
246+
() {
247+
final liveGenerationConfig = LiveGenerationConfig(
248+
inputAudioTranscription: AudioTranscriptionConfig(),
249+
outputAudioTranscription: AudioTranscriptionConfig(),
250+
);
251+
// The two transcription configs are intentionally absent from toJson().
252+
expect(liveGenerationConfig.toJson(), {});
253+
});
254+
255+
test('parseServerMessage parses serverContent with transcriptions', () {
256+
final jsonObject = {
257+
'serverContent': {
258+
'modelTurn': {
259+
'parts': [
260+
{'text': 'Hello, world!'}
261+
]
262+
},
263+
'turnComplete': true,
264+
'inputTranscription': {'text': 'input', 'finished': true},
265+
'outputTranscription': {'text': 'output', 'finished': false}
266+
}
267+
};
268+
final response = parseServerResponse(jsonObject);
269+
expect(response.message, isA<LiveServerContent>());
270+
final contentMessage = response.message as LiveServerContent;
271+
expect(contentMessage.turnComplete, true);
272+
expect(contentMessage.modelTurn, isA<Content>());
273+
expect(contentMessage.inputTranscription, isA<Transcription>());
274+
expect(contentMessage.inputTranscription?.text, 'input');
275+
expect(contentMessage.inputTranscription?.finished, true);
276+
expect(contentMessage.outputTranscription, isA<Transcription>());
277+
expect(contentMessage.outputTranscription?.text, 'output');
278+
expect(contentMessage.outputTranscription?.finished, false);
279+
});
243280
});
244281
}

0 commit comments

Comments
 (0)