diff --git a/apps/m3u/tasks.py b/apps/m3u/tasks.py
index 2617e7ea..af995d96 100644
--- a/apps/m3u/tasks.py
+++ b/apps/m3u/tasks.py
@@ -5,6 +5,7 @@ import requests
 import os
 import gc
 import gzip, zipfile
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from celery.app.control import Inspect
 from celery.result import AsyncResult
 from celery import shared_task, current_app, group
@@ -31,7 +32,7 @@ from core.utils import send_websocket_update
 
 logger = logging.getLogger(__name__)
 
-BATCH_SIZE = 1000
+BATCH_SIZE = 1500  # Optimized batch size for threading
 
 m3u_dir = os.path.join(settings.MEDIA_ROOT, "cached_m3u")
 
@@ -314,8 +315,17 @@ def process_groups(account, groups):
 
     ChannelGroupM3UAccount.objects.bulk_create(relations, ignore_conflicts=True)
 
-@shared_task
 def process_xc_category(account_id, batch, groups, hash_keys):
+    """Legacy Celery task wrapper - calls the direct function."""
+    return process_xc_category_direct(account_id, batch, groups, hash_keys)
+
+
+def process_xc_category_direct(account_id, batch, groups, hash_keys):
+    from django.db import connections
+
+    # Ensure clean database connections for threading
+    connections.close_all()
+
     account = M3UAccount.objects.get(id=account_id)
 
     streams_to_create = []
@@ -394,21 +404,21 @@ def process_xc_category(account_id, batch, groups, hash_keys):
             # Process all found streams
             existing_streams = {
                 s.stream_hash: s
-                for s in Stream.objects.filter(stream_hash__in=stream_hashes.keys())
+                for s in Stream.objects.filter(stream_hash__in=stream_hashes.keys()).select_related('m3u_account').only(
+                    'id', 'stream_hash', 'name', 'url', 'logo_url', 'tvg_id', 'custom_properties', 'last_seen', 'updated_at', 'm3u_account'
+                )
             }
 
             for stream_hash, stream_props in stream_hashes.items():
                 if stream_hash in existing_streams:
                     obj = existing_streams[stream_hash]
-                    existing_attr = {
-                        field.name: getattr(obj, field.name)
-                        for field in Stream._meta.fields
-                        if field != "channel_group_id"
-                    }
-                    changed = any(
-                        existing_attr[key] != value
-                        for key, value in stream_props.items()
-                        if key != "channel_group_id"
+                    # Optimized field comparison for XC streams
+                    changed = (
+                        obj.name != stream_props["name"] or
+                        obj.url != stream_props["url"] or
+                        obj.logo_url != stream_props["logo_url"] or
+                        obj.tvg_id != stream_props["tvg_id"] or
+                        obj.custom_properties != stream_props["custom_properties"]
                     )
 
                     if changed:
@@ -478,17 +488,32 @@ def process_xc_category(account_id, batch, groups, hash_keys):
     except Exception as e:
         logger.error(f"XC category processing error: {str(e)}")
         retval = f"Error processing XC batch: {str(e)}"
+    finally:
+        # Clean up database connections for threading
+        connections.close_all()
 
     # Aggressive garbage collection
     del streams_to_create, streams_to_update, stream_hashes, existing_streams
     gc.collect()
 
+    # Clean up database connections for threading
+    connections.close_all()
+
     return retval
 
 
-@shared_task
 def process_m3u_batch(account_id, batch, groups, hash_keys):
-    """Processes a batch of M3U streams using bulk operations."""
+    """Legacy Celery task wrapper - calls the direct function."""
+    return process_m3u_batch_direct(account_id, batch, groups, hash_keys)
+
+
+def process_m3u_batch_direct(account_id, batch, groups, hash_keys):
+    """Processes a batch of M3U streams using bulk operations with thread-safe DB connections."""
+    from django.db import connections
+
+    # Ensure clean database connections for threading
+    connections.close_all()
+
     account = M3UAccount.objects.get(id=account_id)
 
     compiled_filters = [
@@ -571,84 +596,56 @@ def process_m3u_batch(account_id, batch, groups, hash_keys):
     existing_streams = {
         s.stream_hash: s
-        for s in Stream.objects.filter(stream_hash__in=stream_hashes.keys())
+        for s in Stream.objects.filter(stream_hash__in=stream_hashes.keys()).select_related('m3u_account').only(
+            'id', 'stream_hash', 'name', 'url', 'logo_url', 'tvg_id', 'custom_properties', 'last_seen', 'updated_at', 'm3u_account'
+        )
     }
 
     for stream_hash, stream_props in stream_hashes.items():
         if stream_hash in existing_streams:
             obj = existing_streams[stream_hash]
-            existing_attr = {
-                field.name: getattr(obj, field.name)
-                for field in Stream._meta.fields
-                if field != "channel_group_id"
-            }
-            changed = any(
-                existing_attr[key] != value
-                for key, value in stream_props.items()
-                if key != "channel_group_id"
+            # Optimized field comparison
+            changed = (
+                obj.name != stream_props["name"] or
+                obj.url != stream_props["url"] or
+                obj.logo_url != stream_props["logo_url"] or
+                obj.tvg_id != stream_props["tvg_id"] or
+                obj.custom_properties != stream_props["custom_properties"]
            )
 
+            # Always update last_seen
+            obj.last_seen = timezone.now()
+
             if changed:
-                for key, value in stream_props.items():
-                    setattr(obj, key, value)
-                obj.last_seen = timezone.now()
-                obj.updated_at = (
-                    timezone.now()
-                )  # Update timestamp only for changed streams
-                streams_to_update.append(obj)
-                del existing_streams[stream_hash]
-            else:
-                # Always update last_seen, even if nothing else changed
-                obj.last_seen = timezone.now()
-                # Don't update updated_at for unchanged streams
-                streams_to_update.append(obj)
-                existing_streams[stream_hash] = obj
+                # Only update fields that changed and set updated_at
+                obj.name = stream_props["name"]
+                obj.url = stream_props["url"]
+                obj.logo_url = stream_props["logo_url"]
+                obj.tvg_id = stream_props["tvg_id"]
+                obj.custom_properties = stream_props["custom_properties"]
+                obj.updated_at = timezone.now()
+
+            streams_to_update.append(obj)
         else:
+            # New stream
             stream_props["last_seen"] = timezone.now()
-            stream_props["updated_at"] = (
-                timezone.now()
-            )  # Set initial updated_at for new streams
+            stream_props["updated_at"] = timezone.now()
             streams_to_create.append(Stream(**stream_props))
 
     try:
         with transaction.atomic():
             if streams_to_create:
                 Stream.objects.bulk_create(streams_to_create, ignore_conflicts=True)
+
             if streams_to_update:
-                # We need to split the bulk update to correctly handle updated_at
-                # First, get the subset of streams that have content changes
-                changed_streams = [
-                    s
-                    for s in streams_to_update
-                    if hasattr(s, "updated_at") and s.updated_at
-                ]
-                unchanged_streams = [
-                    s
-                    for s in streams_to_update
-                    if not hasattr(s, "updated_at") or not s.updated_at
-                ]
-
-                # Update changed streams with all fields including updated_at
-                if changed_streams:
-                    Stream.objects.bulk_update(
-                        changed_streams,
-                        {
-                            key
-                            for key in stream_props.keys()
-                            if key not in ["m3u_account", "stream_hash"]
-                            and key not in hash_keys
-                        }
-                        | {"last_seen", "updated_at"},
-                    )
-
-                # Update unchanged streams with only last_seen
-                if unchanged_streams:
-                    Stream.objects.bulk_update(unchanged_streams, ["last_seen"])
-
-            if len(existing_streams.keys()) > 0:
-                Stream.objects.bulk_update(existing_streams.values(), ["last_seen"])
+                # Update all streams in a single bulk operation
+                Stream.objects.bulk_update(
+                    streams_to_update,
+                    ['name', 'url', 'logo_url', 'tvg_id', 'custom_properties', 'last_seen', 'updated_at'],
+                    batch_size=200
+                )
     except Exception as e:
-        logger.error(f"Bulk create failed: {str(e)}")
+        logger.error(f"Bulk operation failed: {str(e)}")
 
     retval = f"M3U account: {account_id}, Batch processed: {len(streams_to_create)} created, {len(streams_to_update)} updated."
 
@@ -657,6 +654,9 @@ def process_m3u_batch(account_id, batch, groups, hash_keys):
 
     # from core.utils import cleanup_memory
     # cleanup_memory(log_usage=True, force_collection=True)
 
+    # Clean up database connections for threading
+    connections.close_all()
+
     return retval
 
@@ -1764,19 +1764,86 @@ def refresh_single_m3u_account(account_id):
         account.status = M3UAccount.Status.PARSING
         account.save(update_fields=["status"])
 
+        # Commit any pending transactions before threading
+        from django.db import transaction
+        transaction.commit()
+
+        # Initialize stream counters
+        streams_created = 0
+        streams_updated = 0
+
         if account.account_type == M3UAccount.Types.STADNARD:
             logger.debug(
                 f"Processing Standard account ({account_id}) with groups: {existing_groups}"
             )
-            # Break into batches and process in parallel
+            # Break into batches and process with threading - use global batch size
             batches = [
                 extinf_data[i : i + BATCH_SIZE]
                 for i in range(0, len(extinf_data), BATCH_SIZE)
             ]
-            task_group = group(
-                process_m3u_batch.s(account_id, batch, existing_groups, hash_keys)
-                for batch in batches
-            )
+
+            logger.info(f"Processing {len(extinf_data)} streams in {len(batches)} thread batches")
+
+            # Use 2 threads for optimal database connection handling
+            max_workers = min(2, len(batches))
+            logger.debug(f"Using {max_workers} threads for processing")
+
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Submit batch processing tasks using direct functions (now thread-safe)
+                future_to_batch = {
+                    executor.submit(process_m3u_batch_direct, account_id, batch, existing_groups, hash_keys): i
+                    for i, batch in enumerate(batches)
+                }
+
+                completed_batches = 0
+                total_batches = len(batches)
+
+                # Process completed batches as they finish
+                for future in as_completed(future_to_batch):
+                    batch_idx = future_to_batch[future]
+                    try:
+                        result = future.result()
+                        completed_batches += 1
+
+                        # Extract stream counts from result
+                        if isinstance(result, str):
+                            try:
+                                created_match = re.search(r"(\d+) created", result)
+                                updated_match = re.search(r"(\d+) updated", result)
+                                if created_match and updated_match:
+                                    created_count = int(created_match.group(1))
+                                    updated_count = int(updated_match.group(1))
+                                    streams_created += created_count
+                                    streams_updated += updated_count
+                            except (AttributeError, ValueError):
+                                pass
+
+                        # Send progress update
+                        progress = int((completed_batches / total_batches) * 100)
+                        current_elapsed = time.time() - start_time
+
+                        if progress > 0:
+                            estimated_total = (current_elapsed / progress) * 100
+                            time_remaining = max(0, estimated_total - current_elapsed)
+                        else:
+                            time_remaining = 0
+
+                        send_m3u_update(
+                            account_id,
+                            "parsing",
+                            progress,
+                            elapsed_time=current_elapsed,
+                            time_remaining=time_remaining,
+                            streams_processed=streams_created + streams_updated,
+                        )
+
+                        logger.debug(f"Thread batch {completed_batches}/{total_batches} completed")
+
+                    except Exception as e:
+                        logger.error(f"Error in thread batch {batch_idx}: {str(e)}")
+                        completed_batches += 1  # Still count it to avoid hanging
+
+            logger.info(f"Thread-based processing completed for account {account_id}")
         else:
             # For XC accounts, get the groups with their custom properties containing xc_id
             logger.debug(f"Processing XC account with groups: {existing_groups}")
@@ -1819,92 +1886,80 @@ def refresh_single_m3u_account(account_id):
                 f"Filtered {len(filtered_groups)} groups for processing: {filtered_groups}"
             )
 
-            # Batch the groups
+            # Batch the groups - use reasonable group batch size for XC processing
+            GROUP_BATCH_SIZE = 2  # Process 2 groups per batch for XC
             filtered_groups_list = list(filtered_groups.items())
             batches = [
-                dict(filtered_groups_list[i : i + 2])
-                for i in range(0, len(filtered_groups_list), 2)
+                dict(filtered_groups_list[i : i + GROUP_BATCH_SIZE])
+                for i in range(0, len(filtered_groups_list), GROUP_BATCH_SIZE)
             ]
 
             logger.info(f"Created {len(batches)} batches for XC processing")
 
-            task_group = group(
-                process_xc_category.s(account_id, batch, existing_groups, hash_keys)
-                for batch in batches
-            )
-        total_batches = len(batches)
-        completed_batches = 0
-        streams_processed = 0  # Track total streams processed
-        logger.debug(
-            f"Dispatched {len(batches)} parallel tasks for account_id={account_id}."
-        )
+            # Use threading for XC processing instead of Celery group
+            max_workers = min(2, len(batches))
+            logger.debug(f"Using {max_workers} threads for XC processing")
 
-        # result = task_group.apply_async()
-        result = task_group.apply_async()
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Submit XC batch processing tasks using direct functions (now thread-safe)
+                future_to_batch = {
+                    executor.submit(process_xc_category_direct, account_id, batch, existing_groups, hash_keys): i
+                    for i, batch in enumerate(batches)
+                }
 
-        # Wait for all tasks to complete and collect their result IDs
-        completed_task_ids = set()
-        while completed_batches < total_batches:
-            for async_result in result:
-                if (
-                    async_result.ready() and async_result.id not in completed_task_ids
-                ):  # If the task has completed and we haven't counted it
-                    task_result = async_result.result  # The result of the task
-                    logger.debug(f"Task completed with result: {task_result}")
+                completed_batches = 0
+                total_batches = len(batches)
 
-                    # Extract stream counts from result string if available
-                    if isinstance(task_result, str):
-                        try:
-                            created_match = re.search(r"(\d+) created", task_result)
-                            updated_match = re.search(r"(\d+) updated", task_result)
+                # Process completed batches as they finish
+                for future in as_completed(future_to_batch):
+                    batch_idx = future_to_batch[future]
+                    try:
+                        result = future.result()
+                        completed_batches += 1
 
-                            if created_match and updated_match:
-                                created_count = int(created_match.group(1))
-                                updated_count = int(updated_match.group(1))
-                                streams_processed += created_count + updated_count
-                                streams_created += created_count
-                                streams_updated += updated_count
-                        except (AttributeError, ValueError):
-                            pass
+                        # Extract stream counts from result
+                        if isinstance(result, str):
+                            try:
+                                created_match = re.search(r"(\d+) created", result)
+                                updated_match = re.search(r"(\d+) updated", result)
+                                if created_match and updated_match:
+                                    created_count = int(created_match.group(1))
+                                    updated_count = int(updated_match.group(1))
+                                    streams_created += created_count
+                                    streams_updated += updated_count
+                            except (AttributeError, ValueError):
+                                pass
 
-                    completed_batches += 1
-                    completed_task_ids.add(
-                        async_result.id
-                    )  # Mark this task as processed
+                        # Send progress update
+                        progress = int((completed_batches / total_batches) * 100)
+                        current_elapsed = time.time() - start_time
 
-                    # Calculate progress
-                    progress = int((completed_batches / total_batches) * 100)
+                        if progress > 0:
+                            estimated_total = (current_elapsed / progress) * 100
+                            time_remaining = max(0, estimated_total - current_elapsed)
+                        else:
+                            time_remaining = 0
 
-                    # Calculate elapsed time and estimated remaining time
-                    current_elapsed = time.time() - start_time
-                    if progress > 0:
-                        estimated_total = (current_elapsed / progress) * 100
-                        time_remaining = max(0, estimated_total - current_elapsed)
-                    else:
-                        time_remaining = 0
+                        send_m3u_update(
+                            account_id,
+                            "parsing",
+                            progress,
+                            elapsed_time=current_elapsed,
+                            time_remaining=time_remaining,
+                            streams_processed=streams_created + streams_updated,
+                        )
 
-                    # Send progress update via Channels
-                    # Don't send 100% because we want to clean up after
-                    if progress == 100:
-                        progress = 99
+                        logger.debug(f"XC thread batch {completed_batches}/{total_batches} completed")
 
-                    send_m3u_update(
-                        account_id,
-                        "parsing",
-                        progress,
-                        elapsed_time=current_elapsed,
-                        time_remaining=time_remaining,
-                        streams_processed=streams_processed,
-                    )
+                    except Exception as e:
+                        logger.error(f"Error in XC thread batch {batch_idx}: {str(e)}")
+                        completed_batches += 1  # Still count it to avoid hanging
 
-                    # Optionally remove completed task from the group to prevent processing it again
-                    result.remove(async_result)
-                else:
-                    logger.trace(f"Task is still running.")
+            logger.info(f"XC thread-based processing completed for account {account_id}")
 
         # Ensure all database transactions are committed before cleanup
         logger.info(
-            f"All {total_batches} tasks completed, ensuring DB transactions are committed before cleanup"
+            f"All thread processing completed, ensuring DB transactions are committed before cleanup"
         )
         # Force a simple DB query to ensure connection sync
         Stream.objects.filter(
@@ -1933,6 +1988,9 @@ def refresh_single_m3u_account(account_id):
         # Calculate elapsed time
         elapsed_time = time.time() - start_time
 
+        # Calculate total streams processed
+        streams_processed = streams_created + streams_updated
+
         # Set status to success and update timestamp BEFORE sending the final update
         account.status = M3UAccount.Status.SUCCESS
         account.last_message = (
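
For reference, a minimal standalone sketch of the thread-safe batching pattern this patch adopts. ThreadPoolExecutor, as_completed, and Django's connections.close_all() are the real APIs used in the diff above; the process_batch worker and run_batches driver below are illustrative placeholders, and the snippet assumes it runs inside an already-configured Django project. Returning a "N created, M updated" string from each worker mirrors how refresh_single_m3u_account parses results with a regex.

    from concurrent.futures import ThreadPoolExecutor, as_completed

    from django.db import connections


    def process_batch(batch):
        # Illustrative worker: each thread manages its own DB connections. Drop any
        # inherited connection before doing ORM work, and close again when done.
        connections.close_all()
        try:
            # ... bulk_create / bulk_update work would happen here ...
            return f"Batch processed: {len(batch)} created, 0 updated."
        finally:
            connections.close_all()


    def run_batches(batches):
        results = []
        with ThreadPoolExecutor(max_workers=min(2, len(batches))) as executor:
            future_to_batch = {executor.submit(process_batch, b): i for i, b in enumerate(batches)}
            for future in as_completed(future_to_batch):
                # Collect each batch's result as it finishes; future.result() re-raises
                # any exception from the worker so failures are not silently dropped.
                results.append((future_to_batch[future], future.result()))
        return results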