Skip to content

Commit

Permalink
UCF backbone overwrite (#75)
Browse files Browse the repository at this point in the history
* Backbone now loaded from path in UCF detector config rather than train config

* Updated ad hoc detector unit tests, added one script that runs all tests

* Removed deprecated mining model weights directory

* UCFDetector config no longer provides backbone path, instead uses training config backbone path which defaults to xception-best.pth on HF

* Updated backbone load check

* Training now uses xception-best.pth backbone weights from bitmind/bm-ucf/ on HuggingFace
  • Loading branch information
aliang322 authored Oct 1, 2024
1 parent a0b0e6a commit dfa1614
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion base_miner/UCF/config/ucf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
log_dir: ../debug_logs/ucf

# model setting
pretrained: ../weights/xception_best.pth # path to a pre-trained model, if using one
pretrained: ../weights/xception-best.pth # path to a pre-trained model, if using one
model_name: ucf # model name
backbone_name: xception # backbone name
encoder_feat_dim: 512 # feature dimension of the backbone
Expand Down
2 changes: 1 addition & 1 deletion base_miner/UCF/train_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def main():
ensure_backbone_is_available(
logger=logger,
model_filename=config['pretrained'].split('/')[-1],
hugging_face_repo_name='bitmind/' + config['model_name']
hugging_face_repo_name='bitmind/bm-ucf'
)

# prepare the model (detector)
Expand Down
3 changes: 1 addition & 2 deletions base_miner/deepfake_detectors/configs/ucf.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# UCFDetector Generalist Configuration
hf_repo: 'bitmind/bm-ucf' # Hugging Face repository for downloading model files
train_config: 'bm-general-config.yaml' # pre-trained configuration file in HuggingFace
weights: 'bm-general.pth' # UCF model checkpoint in HuggingFace
backbone_weights: 'xception-best.pth' # backbone model checkpoint in HuggingFace
weights: 'bm-general.pth' # UCF model checkpoint in HuggingFace
1 change: 0 additions & 1 deletion base_miner/deepfake_detectors/configs/ucf_face.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@
hf_repo: 'bitmind/bm-ucf' # Hugging Face repository for downloading model files
train_config: 'bm-faces-config.yaml' # pre-trained configuration file in HuggingFace
weights: 'bm-faces.pth' # UCF model checkpoint in HuggingFace
backbone_weights: 'xception-best.pth' # backbone model checkpoint in HuggingFace
2 changes: 1 addition & 1 deletion base_miner/deepfake_detectors/ucf_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def load_model(self):
self.init_cudnn()
self.init_seed()
self.ensure_weights_are_available(self.weights)
self.ensure_weights_are_available(self.backbone_weights)
self.ensure_weights_are_available(self.train_config['pretrained'].split('/')[-1])
model_class = DETECTOR[self.train_config['model_name']]
bt.logging.info(f"Loaded config from training run: {self.train_config}")
self.model = model_class(self.train_config).to(self.device)
Expand Down
18 changes: 18 additions & 0 deletions base_miner/deepfake_detectors/unit_tests/run_all_unit_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import subprocess

def run_all_py_scripts(directory):
# List all files in the directory
for filename in os.listdir(directory):
# Check if the file ends with .py and is not this script itself
if filename.endswith('.py') and filename != os.path.basename(__file__):
# Full path of the python file
filepath = os.path.join(directory, filename)
print(f"Running {filename}...")

# Run the script using subprocess
subprocess.run(['python', filepath])

if __name__ == "__main__":
# Run all python files in the current directory
run_all_py_scripts(os.getcwd())
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def test_load_model(self):

def test_load_gates(self):
"""Test if the models load properly with the given weight paths."""
self.assertIsNotNone(self.camo_detector.gate, "Gate should not be None")
self.assertIsNotNone(self.camo_detector.gate.gates, "GatingMechanism gates not be None")
self.assertIsNotNone(self.camo_detector.gating_mechanism, "GatingMechanism gates not be None")

def test_call(self):
"""Test the __call__ method for inference on a given image."""
Expand Down
4 changes: 2 additions & 2 deletions base_miner/deepfake_detectors/unit_tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class TestDetectorRegistry(unittest.TestCase):

def test_registry_contents(self):
from detector_registry import Registry
from base_miner.registry import Registry
detector_registry = Registry()
# Check if the registry has the expected keys (class names or custom names)
registered_keys = list(detector_registry.data.keys())
Expand All @@ -25,7 +25,7 @@ def test_registry_contents(self):
self.assertEqual(len(registered_keys), 0, "There should be no registered detectors.")

def test_registry_contents_after_import(self):
from base_miner.deepfake_detectors import DETECTOR_REGISTRY
from base_miner import DETECTOR_REGISTRY
# Check if the registry has the expected keys (class names or custom names)
registered_keys = list(DETECTOR_REGISTRY.data.keys())

Expand Down
6 changes: 2 additions & 4 deletions base_miner/deepfake_detectors/unit_tests/test_ucf_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,17 @@ def test_ensure_weights(self):
"""Test if the weights are checked and downloaded if missing."""
self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector.weights).exists(),
"Model weights should be available after initialization.")
self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector.backbone_weights).exists(),
self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector.train_config['pretrained'].split('/')[-1]).exists(),
"Backbone weights should be available after initialization.")
self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector_face.weights).exists(),
"Face model weights should be available after initialization.")
self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector_face.backbone_weights).exists(),
self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector_face.train_config['pretrained'].split('/')[-1]).exists(),
"Face backbone weights should be available after initialization.")

def test_model_loading(self):
"""Test if the model is loaded properly."""
self.assertIsNotNone(self.ucf_detector.model, "Generalist model should not be None")
self.assertIsNone(self.ucf_detector.gate, "Generalist gate should be None")
self.assertIsNotNone(self.ucf_detector_face.model, "Face model should not be None")
self.assertIsNotNone(self.ucf_detector_face.gate, "Face gate should not be None")

def test_infer_general(self):
"""Test a basic inference to ensure model outputs are correct."""
Expand Down
Binary file removed mining_models/base.pth
Binary file not shown.

0 comments on commit dfa1614

Please sign in to comment.