Skip to content

Commit b17fe0d

Browse files
No public description
PiperOrigin-RevId: 814387436
1 parent 85743be commit b17fe0d

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Classifies images based on Image Classifier."""
16+
17+
import glob
18+
import os
19+
import shutil
20+
21+
from absl import flags
22+
import torch
23+
24+
from official.projects.waste_identification_ml.llm_applications.milk_pouch_detection import models
25+
26+
27+
FLAGS = flags.FLAGS
28+
flags.DEFINE_string(
29+
"input_dir", "tempdir", "Directory containing image frames to process."
30+
)
31+
32+
33+
# Path to the custom trained model for Image Classifier.
34+
IMAGE_CLASSIFIER_WEIGHTS = (
35+
"milk_pouch_project/image_classifier_model/best_vit_model_epoch_131.pt"
36+
)
37+
CLASS_NAMES = ["dairy", "other"]
38+
39+
40+
def main(_) -> None:
41+
os.makedirs("dairy_packets", exist_ok=True)
42+
os.makedirs("others", exist_ok=True)
43+
44+
classifier = models.ImageClassifier(
45+
model_path=IMAGE_CLASSIFIER_WEIGHTS,
46+
class_names=CLASS_NAMES,
47+
device="cuda" if torch.cuda.is_available() else "cpu",
48+
)
49+
50+
files = glob.glob(os.path.join(FLAGS.input_dir, "*"))
51+
print(f"Found {len(files)} images to process...")
52+
53+
total_dairy_packets = 0
54+
for path in files:
55+
pred_class, _ = classifier.classify(path)
56+
if pred_class == "dairy":
57+
total_dairy_packets += 1
58+
shutil.move(path, os.path.join("dairy_packets", os.path.basename(path)))
59+
else:
60+
shutil.move(path, os.path.join("others", os.path.basename(path)))
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/bin/bash
2+
# --- Parse command-line arguments ---
3+
while [[ "$#" -gt 0 ]]; do
4+
case $1 in
5+
--input_dir=*) input_dir="${1#*=}"; shift ;;
6+
--category_name=*) category_name="${1#*=}"; shift ;;
7+
*) echo "❌ Unknown parameter passed: $1"; exit 1 ;;
8+
esac
9+
done
10+
11+
# --- Check if required arguments were provided ---
12+
if [ -z "$input_dir" ]; then
13+
echo "❌ Error: --input_dir must be specified"
14+
echo "✅ Usage: ./run_pipeline.sh --input_dir=/path/to/images --category_name=category"
15+
exit 1
16+
fi
17+
18+
if [ -z "$category_name" ]; then
19+
echo "❌ Error: --category_name must be specified"
20+
echo "✅ Usage: ./run_pipeline.sh --input_dir=/path/to/images --category_name=category"
21+
exit 1
22+
fi
23+
24+
# --- Run pipeline ---
25+
echo "✅ Activating virtual environment..."
26+
source myenv/bin/activate
27+
28+
echo "🚀 Running detect_and_segment.py..."
29+
echo " Input directory: $input_dir"
30+
echo " Category name: $category_name"
31+
python3 detect_and_segment.py --input_dir="$input_dir" --category_name="$category_name"
32+
33+
echo "🧠 Running classify.py..."
34+
python3 classify.py
35+
36+
echo "🧹 Deactivating virtual environment..."
37+
deactivate
38+
39+
echo "✅ Done."

0 commit comments

Comments
 (0)