@@ -198,6 +198,45 @@ int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) {
198
198
// VIDEO METADATA QUERY API
199
199
// --------------------------------------------------------------------------
200
200
201
+ void SingleStreamDecoder::sortAllFrames () {
202
+ // Sort the allFrames and keyFrames vecs in each stream, and also sets
203
+ // additional fields of the FrameInfo entries like nextPts and frameIndex
204
+ // This is called at the end of a scan, or when setting a user-defined frame
205
+ // mapping.
206
+ for (auto & [streamIndex, streamInfo] : streamInfos_) {
207
+ std::sort (
208
+ streamInfo.keyFrames .begin (),
209
+ streamInfo.keyFrames .end (),
210
+ [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
211
+ return frameInfo1.pts < frameInfo2.pts ;
212
+ });
213
+ std::sort (
214
+ streamInfo.allFrames .begin (),
215
+ streamInfo.allFrames .end (),
216
+ [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
217
+ return frameInfo1.pts < frameInfo2.pts ;
218
+ });
219
+
220
+ size_t keyFrameIndex = 0 ;
221
+ for (size_t i = 0 ; i < streamInfo.allFrames .size (); ++i) {
222
+ streamInfo.allFrames [i].frameIndex = i;
223
+ if (streamInfo.allFrames [i].isKeyFrame ) {
224
+ TORCH_CHECK (
225
+ keyFrameIndex < streamInfo.keyFrames .size (),
226
+ " The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec." );
227
+ streamInfo.keyFrames [keyFrameIndex].frameIndex = i;
228
+ ++keyFrameIndex;
229
+ }
230
+ if (i + 1 < streamInfo.allFrames .size ()) {
231
+ streamInfo.allFrames [i].nextPts = streamInfo.allFrames [i + 1 ].pts ;
232
+ }
233
+ }
234
+ TORCH_CHECK (
235
+ keyFrameIndex == streamInfo.keyFrames .size (),
236
+ " The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec." );
237
+ }
238
+ }
239
+
201
240
void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex () {
202
241
if (scannedAllStreams_) {
203
242
return ;
@@ -283,40 +322,46 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
283
322
getFFMPEGErrorStringFromErrorCode (status));
284
323
285
324
// Sort all frames by their pts.
286
- for (auto & [streamIndex, streamInfo] : streamInfos_) {
287
- std::sort (
288
- streamInfo.keyFrames .begin (),
289
- streamInfo.keyFrames .end (),
290
- [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
291
- return frameInfo1.pts < frameInfo2.pts ;
292
- });
293
- std::sort (
294
- streamInfo.allFrames .begin (),
295
- streamInfo.allFrames .end (),
296
- [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
297
- return frameInfo1.pts < frameInfo2.pts ;
298
- });
325
+ sortAllFrames ();
326
+ scannedAllStreams_ = true ;
327
+ }
299
328
300
- size_t keyFrameIndex = 0 ;
301
- for (size_t i = 0 ; i < streamInfo.allFrames .size (); ++i) {
302
- streamInfo.allFrames [i].frameIndex = i;
303
- if (streamInfo.allFrames [i].isKeyFrame ) {
304
- TORCH_CHECK (
305
- keyFrameIndex < streamInfo.keyFrames .size (),
306
- " The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec." );
307
- streamInfo.keyFrames [keyFrameIndex].frameIndex = i;
308
- ++keyFrameIndex;
309
- }
310
- if (i + 1 < streamInfo.allFrames .size ()) {
311
- streamInfo.allFrames [i].nextPts = streamInfo.allFrames [i + 1 ].pts ;
312
- }
329
+ void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex (
330
+ int streamIndex,
331
+ FrameMappings customFrameMappings) {
332
+ auto & all_frames = customFrameMappings.all_frames ;
333
+ auto & is_key_frame = customFrameMappings.is_key_frame ;
334
+ auto & duration = customFrameMappings.duration ;
335
+ TORCH_CHECK (
336
+ all_frames.size (0 ) == is_key_frame.size (0 ) &&
337
+ is_key_frame.size (0 ) == duration.size (0 ),
338
+ " all_frames, is_key_frame, and duration from custom_frame_mappings were not same size." );
339
+
340
+ auto & streamMetadata = containerMetadata_.allStreamMetadata [streamIndex];
341
+
342
+ streamMetadata.beginStreamPtsFromContent = all_frames[0 ].item <int64_t >();
343
+ streamMetadata.endStreamPtsFromContent =
344
+ all_frames[-1 ].item <int64_t >() + duration[-1 ].item <int64_t >();
345
+
346
+ auto avStream = formatContext_->streams [streamIndex];
347
+ streamMetadata.beginStreamPtsSecondsFromContent =
348
+ *streamMetadata.beginStreamPtsFromContent * av_q2d (avStream->time_base );
349
+
350
+ streamMetadata.endStreamPtsSecondsFromContent =
351
+ *streamMetadata.endStreamPtsFromContent * av_q2d (avStream->time_base );
352
+
353
+ streamMetadata.numFramesFromContent = all_frames.size (0 );
354
+ for (int64_t i = 0 ; i < all_frames.size (0 ); ++i) {
355
+ FrameInfo frameInfo;
356
+ frameInfo.pts = all_frames[i].item <int64_t >();
357
+ frameInfo.isKeyFrame = is_key_frame[i].item <bool >();
358
+ streamInfos_[streamIndex].allFrames .push_back (frameInfo);
359
+ if (frameInfo.isKeyFrame ) {
360
+ streamInfos_[streamIndex].keyFrames .push_back (frameInfo);
313
361
}
314
- TORCH_CHECK (
315
- keyFrameIndex == streamInfo.keyFrames .size (),
316
- " The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec." );
317
362
}
318
-
319
- scannedAllStreams_ = true ;
363
+ // Sort all frames by their pts
364
+ sortAllFrames () ;
320
365
}
321
366
322
367
ContainerMetadata SingleStreamDecoder::getContainerMetadata () const {
@@ -431,7 +476,8 @@ void SingleStreamDecoder::addStream(
431
476
432
477
void SingleStreamDecoder::addVideoStream (
433
478
int streamIndex,
434
- const VideoStreamOptions& videoStreamOptions) {
479
+ const VideoStreamOptions& videoStreamOptions,
480
+ std::optional<FrameMappings> customFrameMappings) {
435
481
addStream (
436
482
streamIndex,
437
483
AVMEDIA_TYPE_VIDEO,
@@ -456,6 +502,14 @@ void SingleStreamDecoder::addVideoStream(
456
502
streamMetadata.height = streamInfo.codecContext ->height ;
457
503
streamMetadata.sampleAspectRatio =
458
504
streamInfo.codecContext ->sample_aspect_ratio ;
505
+
506
+ if (seekMode_ == SeekMode::custom_frame_mappings) {
507
+ TORCH_CHECK (
508
+ customFrameMappings.has_value (),
509
+ " Please provide frame mappings when using custom_frame_mappings seek mode." );
510
+ readCustomFrameMappingsUpdateMetadataAndIndex (
511
+ streamIndex, customFrameMappings.value ());
512
+ }
459
513
}
460
514
461
515
void SingleStreamDecoder::addAudioStream (
@@ -1407,6 +1461,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
1407
1461
int64_t SingleStreamDecoder::secondsToIndexLowerBound (double seconds) {
1408
1462
auto & streamInfo = streamInfos_[activeStreamIndex_];
1409
1463
switch (seekMode_) {
1464
+ case SeekMode::custom_frame_mappings:
1410
1465
case SeekMode::exact: {
1411
1466
auto frame = std::lower_bound (
1412
1467
streamInfo.allFrames .begin (),
@@ -1434,6 +1489,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
1434
1489
int64_t SingleStreamDecoder::secondsToIndexUpperBound (double seconds) {
1435
1490
auto & streamInfo = streamInfos_[activeStreamIndex_];
1436
1491
switch (seekMode_) {
1492
+ case SeekMode::custom_frame_mappings:
1437
1493
case SeekMode::exact: {
1438
1494
auto frame = std::upper_bound (
1439
1495
streamInfo.allFrames .begin (),
@@ -1461,6 +1517,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
1461
1517
int64_t SingleStreamDecoder::getPts (int64_t frameIndex) {
1462
1518
auto & streamInfo = streamInfos_[activeStreamIndex_];
1463
1519
switch (seekMode_) {
1520
+ case SeekMode::custom_frame_mappings:
1464
1521
case SeekMode::exact:
1465
1522
return streamInfo.allFrames [frameIndex].pts ;
1466
1523
case SeekMode::approximate: {
@@ -1485,6 +1542,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
1485
1542
std::optional<int64_t > SingleStreamDecoder::getNumFrames (
1486
1543
const StreamMetadata& streamMetadata) {
1487
1544
switch (seekMode_) {
1545
+ case SeekMode::custom_frame_mappings:
1488
1546
case SeekMode::exact:
1489
1547
return streamMetadata.numFramesFromContent .value ();
1490
1548
case SeekMode::approximate: {
@@ -1498,6 +1556,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
1498
1556
double SingleStreamDecoder::getMinSeconds (
1499
1557
const StreamMetadata& streamMetadata) {
1500
1558
switch (seekMode_) {
1559
+ case SeekMode::custom_frame_mappings:
1501
1560
case SeekMode::exact:
1502
1561
return streamMetadata.beginStreamPtsSecondsFromContent .value ();
1503
1562
case SeekMode::approximate:
@@ -1510,6 +1569,7 @@ double SingleStreamDecoder::getMinSeconds(
1510
1569
std::optional<double > SingleStreamDecoder::getMaxSeconds (
1511
1570
const StreamMetadata& streamMetadata) {
1512
1571
switch (seekMode_) {
1572
+ case SeekMode::custom_frame_mappings:
1513
1573
case SeekMode::exact:
1514
1574
return streamMetadata.endStreamPtsSecondsFromContent .value ();
1515
1575
case SeekMode::approximate: {
@@ -1645,6 +1705,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
1645
1705
return SingleStreamDecoder::SeekMode::exact;
1646
1706
} else if (seekMode == " approximate" ) {
1647
1707
return SingleStreamDecoder::SeekMode::approximate;
1708
+ } else if (seekMode == " custom_frame_mappings" ) {
1709
+ return SingleStreamDecoder::SeekMode::custom_frame_mappings;
1648
1710
} else {
1649
1711
TORCH_CHECK (false , " Invalid seek mode: " + std::string (seekMode));
1650
1712
}
0 commit comments