Skip to content

Commit f2d0b14

Browse files
Dan-Flores and Daniel Flores authored
Add stream_index seek mode, read frame index and update metadata (#764)
Co-authored-by: Daniel Flores <[email protected]>
1 parent 001f091 commit f2d0b14

File tree

7 files changed

+292
-46
lines changed

7 files changed

+292
-46
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 94 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,45 @@ int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) {
198198
// VIDEO METADATA QUERY API
199199
// --------------------------------------------------------------------------
200200

201+
void SingleStreamDecoder::sortAllFrames() {
202+
// Sort the allFrames and keyFrames vecs in each stream, and also sets
203+
// additional fields of the FrameInfo entries like nextPts and frameIndex
204+
// This is called at the end of a scan, or when setting a user-defined frame
205+
// mapping.
206+
for (auto& [streamIndex, streamInfo] : streamInfos_) {
207+
std::sort(
208+
streamInfo.keyFrames.begin(),
209+
streamInfo.keyFrames.end(),
210+
[](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
211+
return frameInfo1.pts < frameInfo2.pts;
212+
});
213+
std::sort(
214+
streamInfo.allFrames.begin(),
215+
streamInfo.allFrames.end(),
216+
[](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
217+
return frameInfo1.pts < frameInfo2.pts;
218+
});
219+
220+
size_t keyFrameIndex = 0;
221+
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
222+
streamInfo.allFrames[i].frameIndex = i;
223+
if (streamInfo.allFrames[i].isKeyFrame) {
224+
TORCH_CHECK(
225+
keyFrameIndex < streamInfo.keyFrames.size(),
226+
"The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec.");
227+
streamInfo.keyFrames[keyFrameIndex].frameIndex = i;
228+
++keyFrameIndex;
229+
}
230+
if (i + 1 < streamInfo.allFrames.size()) {
231+
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
232+
}
233+
}
234+
TORCH_CHECK(
235+
keyFrameIndex == streamInfo.keyFrames.size(),
236+
"The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec.");
237+
}
238+
}
239+
201240
void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
202241
if (scannedAllStreams_) {
203242
return;
@@ -283,40 +322,46 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
283322
getFFMPEGErrorStringFromErrorCode(status));
284323

285324
// Sort all frames by their pts.
286-
for (auto& [streamIndex, streamInfo] : streamInfos_) {
287-
std::sort(
288-
streamInfo.keyFrames.begin(),
289-
streamInfo.keyFrames.end(),
290-
[](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
291-
return frameInfo1.pts < frameInfo2.pts;
292-
});
293-
std::sort(
294-
streamInfo.allFrames.begin(),
295-
streamInfo.allFrames.end(),
296-
[](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
297-
return frameInfo1.pts < frameInfo2.pts;
298-
});
325+
sortAllFrames();
326+
scannedAllStreams_ = true;
327+
}
299328

300-
size_t keyFrameIndex = 0;
301-
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
302-
streamInfo.allFrames[i].frameIndex = i;
303-
if (streamInfo.allFrames[i].isKeyFrame) {
304-
TORCH_CHECK(
305-
keyFrameIndex < streamInfo.keyFrames.size(),
306-
"The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec.");
307-
streamInfo.keyFrames[keyFrameIndex].frameIndex = i;
308-
++keyFrameIndex;
309-
}
310-
if (i + 1 < streamInfo.allFrames.size()) {
311-
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
312-
}
329+
void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
330+
int streamIndex,
331+
FrameMappings customFrameMappings) {
332+
auto& all_frames = customFrameMappings.all_frames;
333+
auto& is_key_frame = customFrameMappings.is_key_frame;
334+
auto& duration = customFrameMappings.duration;
335+
TORCH_CHECK(
336+
all_frames.size(0) == is_key_frame.size(0) &&
337+
is_key_frame.size(0) == duration.size(0),
338+
"all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.");
339+
340+
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
341+
342+
streamMetadata.beginStreamPtsFromContent = all_frames[0].item<int64_t>();
343+
streamMetadata.endStreamPtsFromContent =
344+
all_frames[-1].item<int64_t>() + duration[-1].item<int64_t>();
345+
346+
auto avStream = formatContext_->streams[streamIndex];
347+
streamMetadata.beginStreamPtsSecondsFromContent =
348+
*streamMetadata.beginStreamPtsFromContent * av_q2d(avStream->time_base);
349+
350+
streamMetadata.endStreamPtsSecondsFromContent =
351+
*streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base);
352+
353+
streamMetadata.numFramesFromContent = all_frames.size(0);
354+
for (int64_t i = 0; i < all_frames.size(0); ++i) {
355+
FrameInfo frameInfo;
356+
frameInfo.pts = all_frames[i].item<int64_t>();
357+
frameInfo.isKeyFrame = is_key_frame[i].item<bool>();
358+
streamInfos_[streamIndex].allFrames.push_back(frameInfo);
359+
if (frameInfo.isKeyFrame) {
360+
streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
313361
}
314-
TORCH_CHECK(
315-
keyFrameIndex == streamInfo.keyFrames.size(),
316-
"The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec.");
317362
}
318-
319-
scannedAllStreams_ = true;
363+
// Sort all frames by their pts
364+
sortAllFrames();
320365
}
321366

322367
ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
@@ -431,7 +476,8 @@ void SingleStreamDecoder::addStream(
431476

432477
void SingleStreamDecoder::addVideoStream(
433478
int streamIndex,
434-
const VideoStreamOptions& videoStreamOptions) {
479+
const VideoStreamOptions& videoStreamOptions,
480+
std::optional<FrameMappings> customFrameMappings) {
435481
addStream(
436482
streamIndex,
437483
AVMEDIA_TYPE_VIDEO,
@@ -456,6 +502,14 @@ void SingleStreamDecoder::addVideoStream(
456502
streamMetadata.height = streamInfo.codecContext->height;
457503
streamMetadata.sampleAspectRatio =
458504
streamInfo.codecContext->sample_aspect_ratio;
505+
506+
if (seekMode_ == SeekMode::custom_frame_mappings) {
507+
TORCH_CHECK(
508+
customFrameMappings.has_value(),
509+
"Please provide frame mappings when using custom_frame_mappings seek mode.");
510+
readCustomFrameMappingsUpdateMetadataAndIndex(
511+
streamIndex, customFrameMappings.value());
512+
}
459513
}
460514

461515
void SingleStreamDecoder::addAudioStream(
@@ -1407,6 +1461,7 @@ int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
14071461
int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
14081462
auto& streamInfo = streamInfos_[activeStreamIndex_];
14091463
switch (seekMode_) {
1464+
case SeekMode::custom_frame_mappings:
14101465
case SeekMode::exact: {
14111466
auto frame = std::lower_bound(
14121467
streamInfo.allFrames.begin(),
@@ -1434,6 +1489,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
14341489
int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
14351490
auto& streamInfo = streamInfos_[activeStreamIndex_];
14361491
switch (seekMode_) {
1492+
case SeekMode::custom_frame_mappings:
14371493
case SeekMode::exact: {
14381494
auto frame = std::upper_bound(
14391495
streamInfo.allFrames.begin(),
@@ -1461,6 +1517,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
14611517
int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14621518
auto& streamInfo = streamInfos_[activeStreamIndex_];
14631519
switch (seekMode_) {
1520+
case SeekMode::custom_frame_mappings:
14641521
case SeekMode::exact:
14651522
return streamInfo.allFrames[frameIndex].pts;
14661523
case SeekMode::approximate: {
@@ -1485,6 +1542,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14851542
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14861543
const StreamMetadata& streamMetadata) {
14871544
switch (seekMode_) {
1545+
case SeekMode::custom_frame_mappings:
14881546
case SeekMode::exact:
14891547
return streamMetadata.numFramesFromContent.value();
14901548
case SeekMode::approximate: {
@@ -1498,6 +1556,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14981556
double SingleStreamDecoder::getMinSeconds(
14991557
const StreamMetadata& streamMetadata) {
15001558
switch (seekMode_) {
1559+
case SeekMode::custom_frame_mappings:
15011560
case SeekMode::exact:
15021561
return streamMetadata.beginStreamPtsSecondsFromContent.value();
15031562
case SeekMode::approximate:
@@ -1510,6 +1569,7 @@ double SingleStreamDecoder::getMinSeconds(
15101569
std::optional<double> SingleStreamDecoder::getMaxSeconds(
15111570
const StreamMetadata& streamMetadata) {
15121571
switch (seekMode_) {
1572+
case SeekMode::custom_frame_mappings:
15131573
case SeekMode::exact:
15141574
return streamMetadata.endStreamPtsSecondsFromContent.value();
15151575
case SeekMode::approximate: {
@@ -1645,6 +1705,8 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
16451705
return SingleStreamDecoder::SeekMode::exact;
16461706
} else if (seekMode == "approximate") {
16471707
return SingleStreamDecoder::SeekMode::approximate;
1708+
} else if (seekMode == "custom_frame_mappings") {
1709+
return SingleStreamDecoder::SeekMode::custom_frame_mappings;
16481710
} else {
16491711
TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
16501712
}

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class SingleStreamDecoder {
2929
// CONSTRUCTION API
3030
// --------------------------------------------------------------------------
3131

32-
enum class SeekMode { exact, approximate };
32+
enum class SeekMode { exact, approximate, custom_frame_mappings };
3333

3434
// Creates a SingleStreamDecoder from the video at videoFilePath.
3535
explicit SingleStreamDecoder(
@@ -53,20 +53,38 @@ class SingleStreamDecoder {
5353
// the allFrames and keyFrames vectors.
5454
void scanFileAndUpdateMetadataAndIndex();
5555

56+
// Sorts the keyFrames and allFrames vectors in each StreamInfo by pts.
57+
void sortAllFrames();
58+
5659
// Returns the metadata for the container.
5760
ContainerMetadata getContainerMetadata() const;
5861

5962
// Returns the key frame indices as a tensor. The tensor is 1D and contains
6063
// int64 values, where each value is the frame index for a key frame.
6164
torch::Tensor getKeyFrameIndices();
6265

66+
// FrameMappings is used for the custom_frame_mappings seek mode to store
67+
// metadata of frames in a stream. The size of all tensors in this struct must
68+
// match.
69+
6370
// --------------------------------------------------------------------------
6471
// ADDING STREAMS API
6572
// --------------------------------------------------------------------------
73+
struct FrameMappings {
74+
// 1D tensor of int64, each value is the PTS of a frame in timebase units.
75+
torch::Tensor all_frames;
76+
// 1D tensor of bool, each value indicates if the corresponding frame in
77+
// all_frames is a key frame.
78+
torch::Tensor is_key_frame;
79+
// 1D tensor of int64, each value is the duration of the corresponding frame
80+
// in all_frames in timebase units.
81+
torch::Tensor duration;
82+
};
6683

6784
void addVideoStream(
6885
int streamIndex,
69-
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
86+
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(),
87+
std::optional<FrameMappings> customFrameMappings = std::nullopt);
7088
void addAudioStream(
7189
int streamIndex,
7290
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
@@ -226,6 +244,13 @@ class SingleStreamDecoder {
226244
// --------------------------------------------------------------------------
227245

228246
void initializeDecoder();
247+
248+
// Reads the user provided frame index and updates each StreamInfo's index,
249+
// i.e. the allFrames and keyFrames vectors, and
250+
// endStreamPtsSecondsFromContent
251+
void readCustomFrameMappingsUpdateMetadataAndIndex(
252+
int streamIndex,
253+
FrameMappings customFrameMappings);
229254
// --------------------------------------------------------------------------
230255
// DECODING APIS AND RELATED UTILS
231256
// --------------------------------------------------------------------------

src/torchcodec/_core/custom_ops.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3636
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
3737
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
3838
m.def(
39-
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()");
39+
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
4040
m.def(
41-
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
41+
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
4242
m.def(
4343
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
4444
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
@@ -105,6 +105,14 @@ OpsFrameOutput makeOpsFrameOutput(FrameOutput& frame) {
105105
torch::tensor(frame.durationSeconds, torch::dtype(torch::kFloat64)));
106106
}
107107

108+
// Repackages the tensor triple received by the op layer as a
// SingleStreamDecoder::FrameMappings. Tuple elements 0, 1 and 2 become
// all_frames, is_key_frame and duration respectively.
SingleStreamDecoder::FrameMappings makeFrameMappings(
    std::tuple<at::Tensor, at::Tensor, at::Tensor> custom_frame_mappings) {
  auto& [allFrames, isKeyFrame, duration] = custom_frame_mappings;
  SingleStreamDecoder::FrameMappings frameMappings{
      allFrames, isKeyFrame, duration};
  return frameMappings;
}
115+
108116
// All elements of this tuple are tensors of the same leading dimension. The
109117
// tuple represents the frames for N total frames, where N is the dimension of
110118
// each stacked tensor. The elements are:
@@ -223,6 +231,8 @@ void _add_video_stream(
223231
std::optional<std::string_view> dimension_order = std::nullopt,
224232
std::optional<int64_t> stream_index = std::nullopt,
225233
std::optional<std::string_view> device = std::nullopt,
234+
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
235+
custom_frame_mappings = std::nullopt,
226236
std::optional<std::string_view> color_conversion_library = std::nullopt) {
227237
VideoStreamOptions videoStreamOptions;
228238
videoStreamOptions.width = width;
@@ -253,9 +263,13 @@ void _add_video_stream(
253263
if (device.has_value()) {
254264
videoStreamOptions.device = createTorchDevice(std::string(device.value()));
255265
}
256-
266+
std::optional<SingleStreamDecoder::FrameMappings> converted_mappings =
267+
custom_frame_mappings.has_value()
268+
? std::make_optional(makeFrameMappings(custom_frame_mappings.value()))
269+
: std::nullopt;
257270
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
258-
videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions);
271+
videoDecoder->addVideoStream(
272+
stream_index.value_or(-1), videoStreamOptions, converted_mappings);
259273
}
260274

261275
// Add a new video stream at `stream_index` using the provided options.
@@ -266,15 +280,18 @@ void add_video_stream(
266280
std::optional<int64_t> num_threads = std::nullopt,
267281
std::optional<std::string_view> dimension_order = std::nullopt,
268282
std::optional<int64_t> stream_index = std::nullopt,
269-
std::optional<std::string_view> device = std::nullopt) {
283+
std::optional<std::string_view> device = std::nullopt,
284+
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
285+
custom_frame_mappings = std::nullopt) {
270286
_add_video_stream(
271287
decoder,
272288
width,
273289
height,
274290
num_threads,
275291
dimension_order,
276292
stream_index,
277-
device);
293+
device,
294+
custom_frame_mappings);
278295
}
279296

280297
void add_audio_stream(

src/torchcodec/_core/ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def _add_video_stream_abstract(
205205
dimension_order: Optional[str] = None,
206206
stream_index: Optional[int] = None,
207207
device: Optional[str] = None,
208+
custom_frame_mappings: Optional[
209+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
210+
] = None,
208211
color_conversion_library: Optional[str] = None,
209212
) -> None:
210213
return
@@ -220,6 +223,9 @@ def add_video_stream_abstract(
220223
dimension_order: Optional[str] = None,
221224
stream_index: Optional[int] = None,
222225
device: Optional[str] = None,
226+
custom_frame_mappings: Optional[
227+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
228+
] = None,
223229
) -> None:
224230
return
225231

0 commit comments

Comments
 (0)