diff --git a/decitala/fragment.py b/decitala/fragment.py index 245f5cbe..17f7e7f4 100644 --- a/decitala/fragment.py +++ b/decitala/fragment.py @@ -701,6 +701,23 @@ def __repr__(self): # class TheorieKarnatique(GeneralFragment): # pass +class Segment(GeneralFragment): + """ + Class that represents a series of equal-duration tones. Used for clustering + in path-finding. This isn't really useful for general use –– better to use the + :obj:`fragment.GeneralFragment` class. + """ + frag_type = "segment" + + def __init__(self, ql, count, **kwargs): + ql_array = [ql] * count + super().__init__(data=ql_array, name=f"Segment-{count}") + self.ql = ql + self.count = count + + def __repr__(self): + return f"" + #################################################################################################### # Some simple queries for quick access. def get_all_greek_feet(): diff --git a/decitala/search.py b/decitala/search.py index b42cc087..5ece9381 100644 --- a/decitala/search.py +++ b/decitala/search.py @@ -14,6 +14,7 @@ import numpy as np from dataclasses import dataclass +from itertools import groupby from .utils import ( successive_ratio_array, @@ -26,7 +27,8 @@ ) from .fragment import ( FragmentEncoder, - GeneralFragment + GeneralFragment, + Segment ) from .hash_table import ( FragmentHashTable @@ -276,7 +278,9 @@ def rolling_hash_search( table, windows=list(range(2, 19)), allow_subdivision=False, - allow_contiguous_summation=False + allow_contiguous_summation=False, + min_segment_length=None, + sbc=0 ): """ Function for searching a score for rhythmic fragments and modifications of rhythmic fragments. @@ -288,12 +292,18 @@ def rolling_hash_search( object or one of its subclasses. :param list windows: The allowed window sizes for search. Default is all integers in range 2-19. :param bool allow_subdivision: Whether to check for subdivisions of a frame in the search. + :param bool min_segment_length: the minimal length of to use :obj:`fragment.Segment` fragments. + Default is ``None``, i.e. segments are not used. + :param int sbc: If ``min_segment_length`` is not ``None``, set the ``sbc`` ('segment boundary + check') to search for fragments using edges (of length ``sbc``) in the search. """ object_list = get_object_indices(filepath=filepath, part_num=part_num, ignore_grace=True) if type(table) == FragmentHashTable: # sensitive to inheritance. table.load() + # TODO: can't this be done in the search itself? If the window is too big, stop. + # that way you don't have to doo all of the O(n)... max_dataset_length = len(max(table.data, key=lambda x: len(x))) max_window_size = min(max_dataset_length, len(object_list)) closest_window = min(windows, key=lambda x: abs(x - max_window_size)) @@ -303,7 +313,18 @@ def rolling_hash_search( fragment_id = 0 fragments_found = [] for this_win in windows: - frames = roll_window(array=object_list, window_size=this_win) + if min_segment_length: + frames = [] + grouped = groupby(object_list, key=lambda x: x.quarterLength) + groups = [list(x) for x in grouped] + for elem in groups: + if len(elem) >= min_segment_length: + first_val = elem[0][0].quarterLength + frames.append(Segment(ql=first_val, count=len(elem))) + + frames = roll_window(array=object_list, window_size=this_win, fn=lambda x: type(x) != Segment) + else: + frames = roll_window(array=object_list, window_size=this_win) for this_frame in frames: frame_ql_array = frame_to_ql_array(this_frame) if len(frame_ql_array) < 2: @@ -390,6 +411,7 @@ def path_finder( split_dict=None, slur_constraint=False, enforce_earliest_start=False, + min_segment_length=None, save_filepath=None, verbose=False ): @@ -409,6 +431,8 @@ def path_finder( Default is ``"dijkstra"``. :param bool slur_constraint: Whether to force slurred fragments to appear in the final path. Only possible if `algorithm="floyd-warshall"`. + :param bool min_segment_length: the minimal length of to use :obj:`fragment.Segment` fragments. + Default is ``None``, i.e. segments are not used. :param str save_filepath: An optional path to a JSON file for saving search results. This file can then be loaded with the :meth:`decitala.utils.loader`. :param bool verbose: Whether to log messages. Default is ``False``. @@ -419,7 +443,8 @@ def path_finder( table=table, windows=windows, allow_subdivision=allow_subdivision, - allow_contiguous_summation=allow_contiguous_summation + allow_contiguous_summation=allow_contiguous_summation, + min_segment_length=min_segment_length ) if not extractions: return None diff --git a/tests/test_fragment.py b/tests/test_fragment.py index 73b1f8a7..75e628e1 100644 --- a/tests/test_fragment.py +++ b/tests/test_fragment.py @@ -16,7 +16,8 @@ FragmentEncoder, FragmentDecoder, get_all_prosodic_meters, - prosodic_meter_query + prosodic_meter_query, + Segment ) from decitala.utils import flatten @@ -183,4 +184,10 @@ def test_prosodic_meter_query_unordered(): ProsodicMeter("Cretic_Tetrameter_3", origin="latin"), ProsodicMeter("Cretic_Tetrameter_5", origin="latin"), ProsodicMeter("Cretic_Tetrameter_6", origin="latin"), - ] \ No newline at end of file + ] + +def test_segment(): + s1 = Segment(ql=0.25, count=138) + assert list(s1.ql_array()) == [s1.ql] * s1.count + assert s1.ql == 0.25 + assert s1.count == 138 \ No newline at end of file